From ce889e990e95a2aa74889f14d281dcc360ac9ed5 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Sun, 24 Mar 2024 22:21:47 +0800 Subject: [PATCH] some minor bug fixs (#41) * fix joinset leak; * fix udp packet format * fix trace log panic * avoid waiting after listener accept --- src/common/mod.rs | 88 +++++++++++++++++++++++++++++++++ src/connector/udp_hole_punch.rs | 27 +++++----- src/gateway/tcp_proxy.rs | 12 +++-- src/instance/listeners.rs | 22 +++++---- src/peers/packet.rs | 15 +++++- src/tunnels/udp_tunnel.rs | 71 +++++++++++++++++--------- 6 files changed, 186 insertions(+), 49 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 6647f58..0aeeed1 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,3 +1,11 @@ +use std::{ + fmt::Debug, + future, + sync::{Arc, Mutex}, +}; +use tokio::task::JoinSet; +use tracing::Instrument; + pub mod config; pub mod constants; pub mod error; @@ -30,3 +38,83 @@ pub type PeerId = u32; pub fn new_peer_id() -> PeerId { rand::random() } + +pub fn join_joinset_background( + js: Arc>>, + origin: String, +) { + let js = Arc::downgrade(&js); + tokio::spawn( + async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + if js.weak_count() == 0 { + tracing::info!("joinset task exit"); + break; + } + + future::poll_fn(|cx| { + tracing::info!("try join joinset tasks"); + let Some(js) = js.upgrade() else { + return std::task::Poll::Ready(()); + }; + + let mut js = js.lock().unwrap(); + while !js.is_empty() { + let ret = js.poll_join_next(cx); + if ret.is_pending() { + return std::task::Poll::Pending; + } + } + + std::task::Poll::Ready(()) + }) + .await; + } + } + .instrument(tracing::info_span!( + "join_joinset_background", + origin = origin + )), + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_join_joinset_backgroud() { + let js = Arc::new(Mutex::new(JoinSet::<()>::new())); + join_joinset_background(js.clone(), "TEST".to_owned()); + js.try_lock().unwrap().spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + }); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + assert!(js.try_lock().unwrap().is_empty()); + + for _ in 0..5 { + js.try_lock().unwrap().spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + }); + tokio::task::yield_now().await; + } + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + for _ in 0..5 { + js.try_lock().unwrap().spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + }); + tokio::task::yield_now().await; + } + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + assert!(js.try_lock().unwrap().is_empty()); + + let weak_js = Arc::downgrade(&js); + drop(js); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + assert_eq!(weak_js.weak_count(), 0); + } +} diff --git a/src/connector/udp_hole_punch.rs b/src/connector/udp_hole_punch.rs index 1eea53a..016528f 100644 --- a/src/connector/udp_hole_punch.rs +++ b/src/connector/udp_hole_punch.rs @@ -8,8 +8,8 @@ use tracing::Instrument; use crate::{ common::{ - constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes, - stun::StunInfoCollectorTrait, PeerId, + constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, + rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId, }, peers::peer_manager::PeerManager, rpc::NatType, @@ -75,9 +75,15 @@ impl UdpHolePunchListener { while let Ok(conn) = listener.accept().await { last_connected_time_clone.store(std::time::Instant::now()); tracing::warn!(?conn, "udp hole punching listener got peer connection"); - if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { - tracing::error!(?e, "failed to add tunnel as server in hole punch listener"); - } + let peer_mgr = peer_mgr.clone(); + tokio::spawn(async move { + if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { + tracing::error!( + ?e, + "failed to add tunnel as server in hole punch listener" + ); + } + }); } running_clone.store(false); @@ -115,7 +121,7 @@ struct UdpHolePunchConnectorData { struct UdpHolePunchRpcServer { data: Arc, - tasks: Arc>>, + tasks: Arc>>, } #[tarpc::server] @@ -140,7 +146,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer { || my_udp_nat_type == NatType::Restricted as i32 { // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second - self.tasks.lock().await.spawn(async move { + self.tasks.lock().unwrap().spawn(async move { for _ in 0..10 { tracing::info!(?local_mapped_addr, "sending hole punching packet"); // generate a 128 bytes vec with random data @@ -164,10 +170,9 @@ impl UdpHolePunchService for UdpHolePunchRpcServer { impl UdpHolePunchRpcServer { pub fn new(data: Arc) -> Self { - Self { - data, - tasks: Arc::new(Mutex::new(JoinSet::new())), - } + let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); + join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned()); + Self { data, tasks } } async fn select_listener(&self) -> Option<(Arc, SocketAddr)> { diff --git a/src/gateway/tcp_proxy.rs b/src/gateway/tcp_proxy.rs index ad30bf2..f5abd52 100644 --- a/src/gateway/tcp_proxy.rs +++ b/src/gateway/tcp_proxy.rs @@ -16,6 +16,7 @@ use tracing::Instrument; use crate::common::error::Result; use crate::common::global_ctx::GlobalCtx; +use crate::common::join_joinset_background; use crate::common::netns::NetNS; use crate::peers::packet::{self, ArchivedPacket}; use crate::peers::peer_manager::PeerManager; @@ -71,7 +72,7 @@ pub struct TcpProxy { peer_manager: Arc, local_port: AtomicU16, - tasks: Arc>>, + tasks: Arc>>, syn_map: SynSockMap, conn_map: ConnSockMap, @@ -215,7 +216,7 @@ impl TcpProxy { peer_manager, local_port: AtomicU16::new(0), - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), syn_map: Arc::new(DashMap::new()), conn_map: Arc::new(DashMap::new()), @@ -247,6 +248,7 @@ impl TcpProxy { self.peer_manager .add_nic_packet_process_pipeline(Box::new(self.clone())) .await; + join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned()); Ok(()) } @@ -268,7 +270,7 @@ impl TcpProxy { tokio::time::sleep(Duration::from_secs(10)).await; } }; - tasks.lock().await.spawn(syn_map_cleaner_task); + tasks.lock().unwrap().spawn(syn_map_cleaner_task); Ok(()) } @@ -312,7 +314,7 @@ impl TcpProxy { let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone()); assert!(old_nat_val.is_none()); - tasks.lock().await.spawn(Self::connect_to_nat_dst( + tasks.lock().unwrap().spawn(Self::connect_to_nat_dst( net_ns.clone(), tcp_stream, conn_map.clone(), @@ -325,7 +327,7 @@ impl TcpProxy { }; self.tasks .lock() - .await + .unwrap() .spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener"))); Ok(()) diff --git a/src/instance/listeners.rs b/src/instance/listeners.rs index 70f2052..36a725c 100644 --- a/src/instance/listeners.rs +++ b/src/instance/listeners.rs @@ -100,15 +100,19 @@ impl ListenerManage tunnel_info.remote_addr.clone(), )); tracing::info!(ret = ?ret, "conn accepted"); - let server_ret = peer_manager.handle_tunnel(ret).await; - if let Err(e) = &server_ret { - global_ctx.issue_event(GlobalCtxEvent::ConnectionError( - tunnel_info.local_addr, - tunnel_info.remote_addr, - e.to_string(), - )); - tracing::error!(error = ?e, "handle conn error"); - } + let peer_manager = peer_manager.clone(); + let global_ctx = global_ctx.clone(); + tokio::spawn(async move { + let server_ret = peer_manager.handle_tunnel(ret).await; + if let Err(e) = &server_ret { + global_ctx.issue_event(GlobalCtxEvent::ConnectionError( + tunnel_info.local_addr, + tunnel_info.remote_addr, + e.to_string(), + )); + tracing::error!(error = ?e, "handle conn error"); + } + }); } } diff --git a/src/peers/packet.rs b/src/peers/packet.rs index bd50c58..943b07c 100644 --- a/src/peers/packet.rs +++ b/src/peers/packet.rs @@ -99,7 +99,7 @@ pub enum PacketType { TaRpc = 6, } -#[derive(Archive, Deserialize, Serialize, Debug)] +#[derive(Archive, Deserialize, Serialize)] #[archive(compare(PartialEq), check_bytes)] // Derives can be passed through to the generated type: pub struct Packet { @@ -109,6 +109,19 @@ pub struct Packet { pub payload: String, } +impl std::fmt::Debug for Packet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}", + self.from_peer, + self.to_peer, + self.packet_type, + &self.payload.as_bytes() + ) + } +} + impl std::fmt::Debug for ArchivedPacket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( diff --git a/src/tunnels/udp_tunnel.rs b/src/tunnels/udp_tunnel.rs index 366169a..3d49826 100644 --- a/src/tunnels/udp_tunnel.rs +++ b/src/tunnels/udp_tunnel.rs @@ -13,7 +13,10 @@ use tokio_util::{ use tracing::Instrument; use crate::{ - common::rkyv_util::{self, encode_to_bytes, vec_to_string}, + common::{ + join_joinset_background, + rkyv_util::{self, encode_to_bytes, vec_to_string}, + }, rpc::TunnelInfo, tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector}, }; @@ -27,7 +30,7 @@ use super::{ pub const UDP_DATA_MTU: usize = 2500; -#[derive(Archive, Deserialize, Serialize, Debug)] +#[derive(Archive, Deserialize, Serialize)] #[archive(compare(PartialEq), check_bytes)] // Derives can be passed through to the generated type: pub enum UdpPacketPayload { @@ -37,14 +40,29 @@ pub enum UdpPacketPayload { Data(String), } +impl std::fmt::Debug for UdpPacketPayload { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); + match self { + UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(), + UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(), + UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(), + UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(), + } + } +} + #[derive(Archive, Deserialize, Serialize, Debug)] #[archive(compare(PartialEq), check_bytes)] #[archive_attr(derive(Debug))] pub struct UdpPacket { pub conn_id: u32, + pub magic: u32, pub payload: UdpPacketPayload, } +const UDP_PACKET_MAGIC: u32 = 0x19941126; + impl std::fmt::Debug for ArchivedUdpPacketPayload { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); @@ -63,6 +81,7 @@ impl UdpPacket { pub fn new_data_packet(conn_id: u32, data: Vec) -> Self { Self { conn_id, + magic: UDP_PACKET_MAGIC, payload: UdpPacketPayload::Data(vec_to_string(data)), } } @@ -70,6 +89,7 @@ impl UdpPacket { pub fn new_hole_punch_packet(data: Vec) -> Self { Self { conn_id: 0, + magic: UDP_PACKET_MAGIC, payload: UdpPacketPayload::HolePunch(vec_to_string(data)), } } @@ -77,6 +97,7 @@ impl UdpPacket { pub fn new_syn_packet(conn_id: u32) -> Self { Self { conn_id, + magic: UDP_PACKET_MAGIC, payload: UdpPacketPayload::Syn, } } @@ -84,6 +105,7 @@ impl UdpPacket { pub fn new_sack_packet(conn_id: u32) -> Self { Self { conn_id, + magic: UDP_PACKET_MAGIC, payload: UdpPacketPayload::Sack, } } @@ -100,6 +122,11 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option { return None; } + if udp_packet.magic != UDP_PACKET_MAGIC { + tracing::warn!(?udp_packet, "udp magic not match"); + return None; + } + let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else { tracing::warn!(?udp_packet, "udp payload not data"); return None; @@ -189,7 +216,7 @@ pub struct UdpTunnelListener { socket: Option>, sock_map: Arc>, - forward_tasks: Arc>>, + forward_tasks: Arc>>, conn_recv: tokio::sync::mpsc::Receiver>, conn_send: Option>>, @@ -202,7 +229,7 @@ impl UdpTunnelListener { addr, socket: None, sock_map: Arc::new(DashMap::new()), - forward_tasks: Arc::new(Mutex::new(JoinSet::new())), + forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), conn_recv, conn_send: Some(conn_send), } @@ -234,7 +261,7 @@ impl UdpTunnelListener { async fn handle_connect( socket: Arc, addr: SocketAddr, - forward_tasks: Arc>>, + forward_tasks: Arc>>, sock_map: Arc>, local_url: url::Url, conn_id: u32, @@ -251,7 +278,7 @@ impl UdpTunnelListener { let addr_copy = addr.clone(); sock_map.insert(addr, Arc::new(Mutex::new(ss_pair))); let ctunnel_stream = ctunnel.pin_stream(); - forward_tasks.lock().await.spawn(async move { + forward_tasks.lock().unwrap().spawn(async move { let ret = ctunnel_stream .map(|v| { tracing::trace!(?v, "udp stream recv something in forward task"); @@ -304,7 +331,7 @@ impl TunnelListener for UdpTunnelListener { let sock_map = self.sock_map.clone(); let conn_send = self.conn_send.take().unwrap(); let local_url = self.local_url().clone(); - self.forward_tasks.lock().await.spawn( + self.forward_tasks.lock().unwrap().spawn( async move { loop { let mut buf = BytesMut::new(); @@ -323,6 +350,11 @@ impl TunnelListener for UdpTunnelListener { continue; }; + if udp_packet.magic != UDP_PACKET_MAGIC { + tracing::info!(?udp_packet, "udp magic not match"); + continue; + } + if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) { let Ok(conn) = Self::handle_connect( socket.clone(), @@ -350,22 +382,7 @@ impl TunnelListener for UdpTunnelListener { .instrument(tracing::info_span!("udp forward task", ?self.socket)), ); - // let forward_tasks_clone = self.forward_tasks.clone(); - // tokio::spawn(async move { - // loop { - // let mut locked_forward_tasks = forward_tasks_clone.lock().await; - // tokio::select! { - // ret = locked_forward_tasks.join_next() => { - // tracing::warn!(?ret, "udp forward task exit"); - // } - // else => { - // drop(locked_forward_tasks); - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - // continue; - // } - // } - // } - // }); + join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned()); Ok(()) } @@ -453,6 +470,14 @@ impl UdpTunnelConnector { ))); }; + if udp_packet.magic != UDP_PACKET_MAGIC { + tracing::info!(?udp_packet, "udp magic not match"); + return Err(super::TunnelError::ConnectError(format!( + "udp connect error, magic not match. magic: {:?}", + udp_packet.magic + ))); + } + if conn_id != udp_packet.conn_id { return Err(super::TunnelError::ConnectError(format!( "udp connect error, conn id not match. conn_id: {:?}, {:?}",