diff --git a/Cargo.lock b/Cargo.lock index 41b2175..10b63ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -289,6 +289,16 @@ dependencies = [ "syn 2.0.74", ] +[[package]] +name = "async-ringbuf" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32690af15155711360e74119b99605416c9e4dfd45b0859bd9af795a50693bec" +dependencies = [ + "futures", + "ringbuf", +] + [[package]] name = "async-signal" version = "0.2.10" @@ -1534,6 +1544,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-recursion", + "async-ringbuf", "async-stream", "async-trait", "atomic-shim", @@ -1580,6 +1591,7 @@ dependencies = [ "regex", "reqwest 0.11.27", "ring 0.17.8", + "ringbuf", "rpc_build", "rstest", "rust-i18n", @@ -4882,6 +4894,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ringbuf" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb0d14419487131a897031a7e81c3b23d092296984fac4eb6df48cc4e3b2f3c5" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "rpc_build" version = "0.1.0" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index f29187e..91c64fc 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -178,6 +178,9 @@ wildmatch = "2.3.4" rust-i18n = "3" sys-locale = "0.3" +ringbuf = "0.4.5" +async-ringbuf = "0.3.1" + [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.52", features = [ "Win32_Networking_WinSock", diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 81c9197..8454829 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -292,7 +292,6 @@ rust_i18n::i18n!("locales", fallback = "en"); impl Cli { fn parse_listeners(no_listener: bool, listeners: Vec) -> Vec { - println!("parsing listeners: {:?}", listeners); let proto_port_offset = vec![("tcp", 0), ("udp", 0), ("wg", 1), ("ws", 1), ("wss", 2)]; if no_listener || listeners.is_empty() { @@ -376,7 +375,6 @@ impl From for TomlConfigLoader { let cfg = TomlConfigLoader::default(); - cfg.set_hostname(cli.hostname); cfg.set_network_identity(NetworkIdentity::new(cli.network_name, cli.network_secret)); diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index 285726e..128cd53 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -598,6 +598,7 @@ impl NicCtx { } Self::do_forward_nic_to_peers_ipv4(ret.unwrap(), mgr.as_ref()).await; } + panic!("nic stream closed"); }); Ok(()) @@ -618,6 +619,7 @@ impl NicCtx { tracing::error!(?ret, "do_forward_tunnel_to_nic sink error"); } } + panic!("peer packet receiver closed"); }); } diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 3a114b8..67167ad 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -522,7 +522,7 @@ mod tests { tests::{connect_peer_manager, wait_route_appear}, }, proto::common::NatType, - tunnel::common::tests::{enable_log, wait_for_condition}, + tunnel::common::tests::wait_for_condition, }; use super::*; diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 0773809..0fa73cf 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -25,6 +25,7 @@ use zerocopy::AsBytes; use crate::{ common::{ config::{NetworkIdentity, NetworkSecretDigest}, + defer, error::Error, global_ctx::ArcGlobalCtx, PeerId, @@ -103,7 +104,9 @@ impl PeerConn { my_peer_id, global_ctx, - tunnel: Arc::new(Mutex::new(Box::new(mpsc_tunnel))), + tunnel: Arc::new(Mutex::new(Box::new(defer::Defer::new(move || { + mpsc_tunnel.close() + })))), sink, recv: Arc::new(Mutex::new(Some(recv))), tunnel_info, diff --git a/easytier/src/tunnel/mpsc.rs b/easytier/src/tunnel/mpsc.rs index f7496d9..37f3498 100644 --- a/easytier/src/tunnel/mpsc.rs +++ b/easytier/src/tunnel/mpsc.rs @@ -7,11 +7,10 @@ use tokio::time::timeout; use crate::common::scoped_task::ScopedTask; -use super::{ - packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, -}; +use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream}; -use tachyonix::{channel, Receiver, Sender}; +// use tokio::sync::mpsc::{channel, error::TrySendError, Receiver, Sender}; +use tachyonix::{channel, Receiver, Sender, TrySendError}; use futures::SinkExt; @@ -26,8 +25,8 @@ impl MpscTunnelSender { pub fn try_send(&self, item: ZCPacket) -> Result<(), TunnelError> { self.0.try_send(item).map_err(|e| match e { - tachyonix::TrySendError::Full(_) => TunnelError::BufferFull, - tachyonix::TrySendError::Closed(_) => TunnelError::Shutdown, + TrySendError::Full(_) => TunnelError::BufferFull, + TrySendError::Closed(_) => TunnelError::Shutdown, }) } } @@ -53,6 +52,7 @@ impl MpscTunnel { break; } } + rx.close(); let close_ret = timeout(Duration::from_secs(5), sink.close()).await; tracing::warn!(?close_ret, "mpsc close sink"); }); @@ -72,7 +72,10 @@ impl MpscTunnel { let item = rx.recv().await.with_context(|| "recv error")?; sink.feed(item).await?; while let Ok(item) = rx.try_recv() { - if let Err(e) = sink.feed(item).await { + if let Err(e) = timeout(Duration::from_secs(5), sink.feed(item)) + .await + .unwrap() + { tracing::error!(?e, "feed error"); break; } diff --git a/easytier/src/tunnel/ring.rs b/easytier/src/tunnel/ring.rs index 1be12cb..58a221d 100644 --- a/easytier/src/tunnel/ring.rs +++ b/easytier/src/tunnel/ring.rs @@ -1,17 +1,15 @@ use std::{ collections::HashMap, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Poll, Waker}, + fmt::Debug, + sync::Arc, + task::{ready, Poll}, }; -use atomicbox::AtomicOptionBox; -use crossbeam_queue::ArrayQueue; +use async_ringbuf::{traits::*, AsyncHeapCons, AsyncHeapProd, AsyncHeapRb}; +use crossbeam::atomic::AtomicCell; use async_trait::async_trait; -use futures::{Sink, Stream}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use once_cell::sync::Lazy; use tokio::sync::{ @@ -30,83 +28,30 @@ use super::{ static RING_TUNNEL_CAP: usize = 128; -#[derive(Debug)] +type RingLock = parking_lot::Mutex<()>; + +type RingItem = SinkItem; + pub struct RingTunnel { id: Uuid, - ring: ArrayQueue, - closed: AtomicBool, - wait_for_new_item: AtomicOptionBox, - wait_for_empty_slot: AtomicOptionBox, + ring_cons_impl: AtomicCell>>, + ring_prod_impl: AtomicCell>>, } impl RingTunnel { - fn wait_for_new_item(&self, cx: &mut std::task::Context<'_>) -> Poll { - let ret = self - .wait_for_new_item - .swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel); - if let Some(old_waker) = ret { - assert!(old_waker.will_wake(cx.waker())); - } - Poll::Pending - } - - fn wait_for_empty_slot(&self, cx: &mut std::task::Context<'_>) -> Poll { - let ret = self - .wait_for_empty_slot - .swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel); - if let Some(old_waker) = ret { - assert!(old_waker.will_wake(cx.waker())); - } - Poll::Pending - } - - fn notify_new_item(&self) { - if let Some(w) = self.wait_for_new_item.take(Ordering::AcqRel) { - tracing::trace!(?self.id, "notify new item"); - w.wake(); - } - } - - fn notify_empty_slot(&self) { - if let Some(w) = self.wait_for_empty_slot.take(Ordering::AcqRel) { - tracing::trace!(?self.id, "notify empty slot"); - w.wake(); - } - } - fn id(&self) -> &Uuid { &self.id } - pub fn len(&self) -> usize { - self.ring.len() - } - - pub fn capacity(&self) -> usize { - self.ring.capacity() - } - - fn close(&self) { - tracing::info!("close ring tunnel {:?}", self.id); - self.closed - .store(true, std::sync::atomic::Ordering::Relaxed); - self.notify_new_item(); - } - - fn closed(&self) -> bool { - self.closed.load(std::sync::atomic::Ordering::Relaxed) - } - pub fn new(cap: usize) -> Self { let id = Uuid::new_v4(); + let ring_impl = AsyncHeapRb::new(cap); + let (ring_prod_impl, ring_cons_impl) = ring_impl.split(); Self { id: id.clone(), - ring: ArrayQueue::new(cap), - closed: AtomicBool::new(false), - - wait_for_new_item: AtomicOptionBox::new(None), - wait_for_empty_slot: AtomicOptionBox::new(None), + ring_cons_impl: AtomicCell::new(Some(ring_cons_impl)), + ring_prod_impl: AtomicCell::new(Some(ring_prod_impl)), } } @@ -117,14 +62,23 @@ impl RingTunnel { } } -#[derive(Debug)] +impl Debug for RingTunnel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RingTunnel").field("id", &self.id).finish() + } +} + pub struct RingStream { - tunnel: Arc, + id: Uuid, + ring_cons_impl: AsyncHeapCons, } impl RingStream { pub fn new(tunnel: Arc) -> Self { - Self { tunnel } + Self { + id: tunnel.id.clone(), + ring_cons_impl: tunnel.ring_cons_impl.take().unwrap(), + } } } @@ -135,56 +89,39 @@ impl Stream for RingStream { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - let s = self.get_mut(); - let ret = s.tunnel.ring.pop(); + let ret = ready!(self.get_mut().ring_cons_impl.poll_next_unpin(cx)); match ret { - Some(v) => { - s.tunnel.notify_empty_slot(); - return Poll::Ready(Some(Ok(v))); - } - None => { - if s.tunnel.closed() { - tracing::warn!("ring recv tunnel {:?} closed", s.tunnel.id()); - return Poll::Ready(None); - } else { - tracing::trace!("waiting recv buffer, id: {}", s.tunnel.id()); - } - s.tunnel.wait_for_new_item(cx) - } + Some(item) => Poll::Ready(Some(Ok(item))), + None => Poll::Ready(None), } } } -#[derive(Debug)] -pub struct RingSink { - tunnel: Arc, +impl Debug for RingStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RingStream") + .field("id", &self.id) + .field("len", &self.ring_cons_impl.base().occupied_len()) + .field("cap", &self.ring_cons_impl.base().capacity()) + .finish() + } } -impl Drop for RingSink { - fn drop(&mut self) { - self.tunnel.close(); - } +pub struct RingSink { + id: Uuid, + ring_prod_impl: AsyncHeapProd, } impl RingSink { pub fn new(tunnel: Arc) -> Self { - Self { tunnel } - } - - pub fn push_no_check(&self, item: SinkItem) -> Result<(), TunnelError> { - if self.tunnel.closed() { - return Err(TunnelError::Shutdown); + Self { + id: tunnel.id.clone(), + ring_prod_impl: tunnel.ring_prod_impl.take().unwrap(), } - - tracing::trace!(id=?self.tunnel.id(), ?item, "send buffer"); - let _ = self.tunnel.ring.push(item); - self.tunnel.notify_new_item(); - - Ok(()) } - pub fn has_empty_slot(&self) -> bool { - self.tunnel.len() < self.tunnel.capacity() + pub fn try_send(&mut self, item: RingItem) -> Result<(), RingItem> { + self.ring_prod_impl.try_push(item) } } @@ -195,37 +132,41 @@ impl Sink for RingSink { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let self_mut = self.get_mut(); - if !self_mut.has_empty_slot() { - if self_mut.tunnel.closed() { - return Poll::Ready(Err(TunnelError::Shutdown)); - } - self_mut.tunnel.wait_for_empty_slot(cx) - } else { - Poll::Ready(Ok(())) - } + let ret = ready!(self.get_mut().ring_prod_impl.poll_ready_unpin(cx)); + Poll::Ready(ret.map_err(|_| TunnelError::Shutdown)) } fn start_send(self: std::pin::Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { - self.push_no_check(item) + self.get_mut() + .ring_prod_impl + .start_send_unpin(item) + .map_err(|_| TunnelError::Shutdown) } fn poll_flush( self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - if self.tunnel.closed() { - return Poll::Ready(Err(TunnelError::Shutdown)); - } - Poll::Ready(Ok(())) + let ret = ready!(self.get_mut().ring_prod_impl.poll_flush_unpin(cx)); + Poll::Ready(ret.map_err(|_| TunnelError::Shutdown)) } fn poll_close( self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.tunnel.close(); - Poll::Ready(Ok(())) + let ret = ready!(self.get_mut().ring_prod_impl.poll_close_unpin(cx)); + Poll::Ready(ret.map_err(|_| TunnelError::Shutdown)) + } +} + +impl Debug for RingSink { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RingSink") + .field("id", &self.id) + .field("len", &self.ring_prod_impl.base().occupied_len()) + .field("cap", &self.ring_prod_impl.base().capacity()) + .finish() } } diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 639be1b..61fcd80 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -151,9 +151,9 @@ async fn forward_from_ring_to_udp( } } -async fn udp_recv_from_socket_forward_task(socket: Arc, f: F) +async fn udp_recv_from_socket_forward_task(socket: Arc, mut f: F) where - F: Fn(ZCPacket, SocketAddr) -> (), + F: FnMut(ZCPacket, SocketAddr) -> (), { let mut buf = BytesMut::new(); loop { @@ -220,7 +220,7 @@ impl UdpConnection { } } - pub fn handle_packet_from_remote(&self, zc_packet: ZCPacket) -> Result<(), TunnelError> { + pub fn handle_packet_from_remote(&mut self, zc_packet: ZCPacket) -> Result<(), TunnelError> { let header = zc_packet.udp_tunnel_header().unwrap(); let conn_id = header.conn_id.get(); @@ -232,12 +232,10 @@ impl UdpConnection { return Err(TunnelError::ConnIdNotMatch(self.conn_id, conn_id)); } - if !self.ring_sender.has_empty_slot() { - return Err(TunnelError::BufferFull); + if let Err(e) = self.ring_sender.try_send(zc_packet) { + tracing::trace!(?e, "ring sender full, drop packet"); } - self.ring_sender.push_no_check(zc_packet)?; - Ok(()) } } @@ -294,8 +292,8 @@ impl UdpTunnelListenerData { return; } - let ring_for_send_udp = Arc::new(RingTunnel::new(128)); - let ring_for_recv_udp = Arc::new(RingTunnel::new(128)); + let ring_for_send_udp = Arc::new(RingTunnel::new(32)); + let ring_for_recv_udp = Arc::new(RingTunnel::new(32)); tracing::debug!( ?ring_for_send_udp, ?ring_for_recv_udp, @@ -336,7 +334,7 @@ impl UdpTunnelListenerData { if header.msg_type == UdpPacketType::Syn as u8 { tokio::spawn(Self::handle_new_connect(self.clone(), addr, zc_packet)); } else if header.msg_type != UdpPacketType::HolePunch as u8 { - let Some(conn) = self.sock_map.get(&addr) else { + let Some(mut conn) = self.sock_map.get_mut(&addr) else { tracing::trace!(?header, "udp forward packet error, connection not found"); return; }; @@ -569,7 +567,7 @@ impl UdpTunnelConnector { let ring_recv = RingStream::new(ring_for_send_udp.clone()); let ring_sender = RingSink::new(ring_for_recv_udp.clone()); - let udp_conn = UdpConnection::new( + let mut udp_conn = UdpConnection::new( socket.clone(), conn_id, dst_addr,