some minor bug fixs (#41)

* fix joinset leak; 

* fix udp packet format

* fix trace log panic

* avoid waiting after listener accept
This commit is contained in:
Sijie.Sun 2024-03-24 22:21:47 +08:00 committed by GitHub
parent 0f6f553010
commit ce889e990e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 186 additions and 49 deletions

View File

@ -1,3 +1,11 @@
use std::{
fmt::Debug,
future,
sync::{Arc, Mutex},
};
use tokio::task::JoinSet;
use tracing::Instrument;
pub mod config;
pub mod constants;
pub mod error;
@ -30,3 +38,83 @@ pub type PeerId = u32;
pub fn new_peer_id() -> PeerId {
rand::random()
}
pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
js: Arc<Mutex<JoinSet<T>>>,
origin: String,
) {
let js = Arc::downgrade(&js);
tokio::spawn(
async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if js.weak_count() == 0 {
tracing::info!("joinset task exit");
break;
}
future::poll_fn(|cx| {
tracing::info!("try join joinset tasks");
let Some(js) = js.upgrade() else {
return std::task::Poll::Ready(());
};
let mut js = js.lock().unwrap();
while !js.is_empty() {
let ret = js.poll_join_next(cx);
if ret.is_pending() {
return std::task::Poll::Pending;
}
}
std::task::Poll::Ready(())
})
.await;
}
}
.instrument(tracing::info_span!(
"join_joinset_background",
origin = origin
)),
);
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_join_joinset_backgroud() {
let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
join_joinset_background(js.clone(), "TEST".to_owned());
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
});
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(js.try_lock().unwrap().is_empty());
for _ in 0..5 {
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
});
tokio::task::yield_now().await;
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
for _ in 0..5 {
js.try_lock().unwrap().spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
});
tokio::task::yield_now().await;
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(js.try_lock().unwrap().is_empty());
let weak_js = Arc::downgrade(&js);
drop(js);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert_eq!(weak_js.weak_count(), 0);
}
}

View File

@ -8,8 +8,8 @@ use tracing::Instrument;
use crate::{
common::{
constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes,
stun::StunInfoCollectorTrait, PeerId,
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
},
peers::peer_manager::PeerManager,
rpc::NatType,
@ -75,9 +75,15 @@ impl UdpHolePunchListener {
while let Ok(conn) = listener.accept().await {
last_connected_time_clone.store(std::time::Instant::now());
tracing::warn!(?conn, "udp hole punching listener got peer connection");
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
tracing::error!(?e, "failed to add tunnel as server in hole punch listener");
}
let peer_mgr = peer_mgr.clone();
tokio::spawn(async move {
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
tracing::error!(
?e,
"failed to add tunnel as server in hole punch listener"
);
}
});
}
running_clone.store(false);
@ -115,7 +121,7 @@ struct UdpHolePunchConnectorData {
struct UdpHolePunchRpcServer {
data: Arc<UdpHolePunchConnectorData>,
tasks: Arc<Mutex<JoinSet<()>>>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
#[tarpc::server]
@ -140,7 +146,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
|| my_udp_nat_type == NatType::Restricted as i32
{
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
self.tasks.lock().await.spawn(async move {
self.tasks.lock().unwrap().spawn(async move {
for _ in 0..10 {
tracing::info!(?local_mapped_addr, "sending hole punching packet");
// generate a 128 bytes vec with random data
@ -164,10 +170,9 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
impl UdpHolePunchRpcServer {
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
Self {
data,
tasks: Arc::new(Mutex::new(JoinSet::new())),
}
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
Self { data, tasks }
}
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {

View File

@ -16,6 +16,7 @@ use tracing::Instrument;
use crate::common::error::Result;
use crate::common::global_ctx::GlobalCtx;
use crate::common::join_joinset_background;
use crate::common::netns::NetNS;
use crate::peers::packet::{self, ArchivedPacket};
use crate::peers::peer_manager::PeerManager;
@ -71,7 +72,7 @@ pub struct TcpProxy {
peer_manager: Arc<PeerManager>,
local_port: AtomicU16,
tasks: Arc<Mutex<JoinSet<()>>>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
syn_map: SynSockMap,
conn_map: ConnSockMap,
@ -215,7 +216,7 @@ impl TcpProxy {
peer_manager,
local_port: AtomicU16::new(0),
tasks: Arc::new(Mutex::new(JoinSet::new())),
tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
syn_map: Arc::new(DashMap::new()),
conn_map: Arc::new(DashMap::new()),
@ -247,6 +248,7 @@ impl TcpProxy {
self.peer_manager
.add_nic_packet_process_pipeline(Box::new(self.clone()))
.await;
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
Ok(())
}
@ -268,7 +270,7 @@ impl TcpProxy {
tokio::time::sleep(Duration::from_secs(10)).await;
}
};
tasks.lock().await.spawn(syn_map_cleaner_task);
tasks.lock().unwrap().spawn(syn_map_cleaner_task);
Ok(())
}
@ -312,7 +314,7 @@ impl TcpProxy {
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
assert!(old_nat_val.is_none());
tasks.lock().await.spawn(Self::connect_to_nat_dst(
tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
net_ns.clone(),
tcp_stream,
conn_map.clone(),
@ -325,7 +327,7 @@ impl TcpProxy {
};
self.tasks
.lock()
.await
.unwrap()
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
Ok(())

View File

@ -100,15 +100,19 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
tunnel_info.remote_addr.clone(),
));
tracing::info!(ret = ?ret, "conn accepted");
let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr,
tunnel_info.remote_addr,
e.to_string(),
));
tracing::error!(error = ?e, "handle conn error");
}
let peer_manager = peer_manager.clone();
let global_ctx = global_ctx.clone();
tokio::spawn(async move {
let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr,
tunnel_info.remote_addr,
e.to_string(),
));
tracing::error!(error = ?e, "handle conn error");
}
});
}
}

View File

@ -99,7 +99,7 @@ pub enum PacketType {
TaRpc = 6,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub struct Packet {
@ -109,6 +109,19 @@ pub struct Packet {
pub payload: String,
}
impl std::fmt::Debug for Packet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
self.from_peer,
self.to_peer,
self.packet_type,
&self.payload.as_bytes()
)
}
}
impl std::fmt::Debug for ArchivedPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(

View File

@ -13,7 +13,10 @@ use tokio_util::{
use tracing::Instrument;
use crate::{
common::rkyv_util::{self, encode_to_bytes, vec_to_string},
common::{
join_joinset_background,
rkyv_util::{self, encode_to_bytes, vec_to_string},
},
rpc::TunnelInfo,
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
};
@ -27,7 +30,7 @@ use super::{
pub const UDP_DATA_MTU: usize = 2500;
#[derive(Archive, Deserialize, Serialize, Debug)]
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub enum UdpPacketPayload {
@ -37,14 +40,29 @@ pub enum UdpPacketPayload {
Data(String),
}
impl std::fmt::Debug for UdpPacketPayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
match self {
UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
}
}
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct UdpPacket {
pub conn_id: u32,
pub magic: u32,
pub payload: UdpPacketPayload,
}
const UDP_PACKET_MAGIC: u32 = 0x19941126;
impl std::fmt::Debug for ArchivedUdpPacketPayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
@ -63,6 +81,7 @@ impl UdpPacket {
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
Self {
conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Data(vec_to_string(data)),
}
}
@ -70,6 +89,7 @@ impl UdpPacket {
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
Self {
conn_id: 0,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
}
}
@ -77,6 +97,7 @@ impl UdpPacket {
pub fn new_syn_packet(conn_id: u32) -> Self {
Self {
conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Syn,
}
}
@ -84,6 +105,7 @@ impl UdpPacket {
pub fn new_sack_packet(conn_id: u32) -> Self {
Self {
conn_id,
magic: UDP_PACKET_MAGIC,
payload: UdpPacketPayload::Sack,
}
}
@ -100,6 +122,11 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
return None;
}
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::warn!(?udp_packet, "udp magic not match");
return None;
}
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
tracing::warn!(?udp_packet, "udp payload not data");
return None;
@ -189,7 +216,7 @@ pub struct UdpTunnelListener {
socket: Option<Arc<UdpSocket>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
forward_tasks: Arc<Mutex<JoinSet<()>>>,
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
@ -202,7 +229,7 @@ impl UdpTunnelListener {
addr,
socket: None,
sock_map: Arc::new(DashMap::new()),
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
conn_recv,
conn_send: Some(conn_send),
}
@ -234,7 +261,7 @@ impl UdpTunnelListener {
async fn handle_connect(
socket: Arc<UdpSocket>,
addr: SocketAddr,
forward_tasks: Arc<Mutex<JoinSet<()>>>,
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
local_url: url::Url,
conn_id: u32,
@ -251,7 +278,7 @@ impl UdpTunnelListener {
let addr_copy = addr.clone();
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
let ctunnel_stream = ctunnel.pin_stream();
forward_tasks.lock().await.spawn(async move {
forward_tasks.lock().unwrap().spawn(async move {
let ret = ctunnel_stream
.map(|v| {
tracing::trace!(?v, "udp stream recv something in forward task");
@ -304,7 +331,7 @@ impl TunnelListener for UdpTunnelListener {
let sock_map = self.sock_map.clone();
let conn_send = self.conn_send.take().unwrap();
let local_url = self.local_url().clone();
self.forward_tasks.lock().await.spawn(
self.forward_tasks.lock().unwrap().spawn(
async move {
loop {
let mut buf = BytesMut::new();
@ -323,6 +350,11 @@ impl TunnelListener for UdpTunnelListener {
continue;
};
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
continue;
}
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
let Ok(conn) = Self::handle_connect(
socket.clone(),
@ -350,22 +382,7 @@ impl TunnelListener for UdpTunnelListener {
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
);
// let forward_tasks_clone = self.forward_tasks.clone();
// tokio::spawn(async move {
// loop {
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
// tokio::select! {
// ret = locked_forward_tasks.join_next() => {
// tracing::warn!(?ret, "udp forward task exit");
// }
// else => {
// drop(locked_forward_tasks);
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// continue;
// }
// }
// }
// });
join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
Ok(())
}
@ -453,6 +470,14 @@ impl UdpTunnelConnector {
)));
};
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, magic not match. magic: {:?}",
udp_packet.magic
)));
}
if conn_id != udp_packet.conn_id {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, conn id not match. conn_id: {:?}, {:?}",