diff --git a/Cargo.lock b/Cargo.lock
index 64f3587..1d2a025 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1623,6 +1623,7 @@ dependencies = [
  "derivative",
  "encoding",
  "futures",
+ "futures-util",
  "gethostname 0.5.0",
  "globwalk",
  "http 1.1.0",
diff --git a/easytier-gui/src-tauri/src/lib.rs b/easytier-gui/src-tauri/src/lib.rs
index 5c057dd..0a5162f 100644
--- a/easytier-gui/src-tauri/src/lib.rs
+++ b/easytier-gui/src-tauri/src/lib.rs
@@ -136,7 +136,7 @@ impl NetworkConfig {
         }
 
         cfg.set_rpc_portal(
-            format!("127.0.0.1:{}", self.rpc_port)
+            format!("0.0.0.0:{}", self.rpc_port)
                 .parse()
                 .with_context(|| format!("failed to parse rpc portal port: {}", self.rpc_port))?,
         );
diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml
index 569fd49..25b4704 100644
--- a/easytier/Cargo.toml
+++ b/easytier/Cargo.toml
@@ -203,6 +203,7 @@ zip = "0.6.6"
 [dev-dependencies]
 serial_test = "3.0.0"
 rstest = "0.18.2"
+futures-util = "0.3.30"
 
 [target.'cfg(target_os = "linux")'.dev-dependencies]
 defguard_wireguard_rs = "0.4.2"
diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs
index 1c6012f..8f3bd45 100644
--- a/easytier/src/common/mod.rs
+++ b/easytier/src/common/mod.rs
@@ -14,6 +14,7 @@ pub mod global_ctx;
 pub mod ifcfg;
 pub mod netns;
 pub mod network;
+pub mod scoped_task;
 pub mod stun;
 pub mod stun_codec_ext;
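Reviewer note on the recurring `127.0.0.1` to `0.0.0.0` switch in this patch: the loopback address only accepts connections from the local host, while the wildcard address listens on every interface, so the RPC portal becomes reachable from other machines. A minimal sketch of the difference (hypothetical listener code, not part of this PR):

```rust
use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Only processes on this host can connect here.
    let _loopback = TcpListener::bind("127.0.0.1:15888").await?;
    // Any host that can route to us can connect here; consider
    // firewalling the port if remote access is not intended.
    let _wildcard = TcpListener::bind("0.0.0.0:15889").await?;
    Ok(())
}
```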
diff --git a/easytier/src/common/scoped_task.rs b/easytier/src/common/scoped_task.rs
new file mode 100644
index 0000000..5669008
--- /dev/null
+++ b/easytier/src/common/scoped_task.rs
@@ -0,0 +1,134 @@
+//! This module provides a wrapper type around Tokio's `JoinHandle`: `ScopedTask`, which aborts the task when it is dropped.
+//! `ScopedTask` can still be awaited to join the child task, and abort-on-drop will still trigger while it is being awaited.
+//!
+//! For example, if task A spawned task B but is doing something else, and task B is waiting for task C to join,
+//! aborting A will also abort both B and C.
+
+use std::future::Future;
+use std::ops::Deref;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::task::JoinHandle;
+
+#[derive(Debug)]
+pub struct ScopedTask<T> {
+    inner: JoinHandle<T>,
+}
+
+impl<T> Drop for ScopedTask<T> {
+    fn drop(&mut self) {
+        self.inner.abort()
+    }
+}
+
+impl<T> Future for ScopedTask<T> {
+    type Output = <JoinHandle<T> as Future>::Output;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        Pin::new(&mut self.inner).poll(cx)
+    }
+}
+
+impl<T> From<JoinHandle<T>> for ScopedTask<T> {
+    fn from(inner: JoinHandle<T>) -> Self {
+        Self { inner }
+    }
+}
+
+impl<T> Deref for ScopedTask<T> {
+    type Target = JoinHandle<T>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ScopedTask;
+    use futures_util::future::pending;
+    use std::sync::{Arc, RwLock};
+    use tokio::task::yield_now;
+
+    struct Sentry(Arc<RwLock<bool>>);
+    impl Drop for Sentry {
+        fn drop(&mut self) {
+            *self.0.write().unwrap() = true
+        }
+    }
+
+    #[tokio::test]
+    async fn drop_while_not_waiting_for_join() {
+        let dropped = Arc::new(RwLock::new(false));
+        let sentry = Sentry(dropped.clone());
+        let task = ScopedTask::from(tokio::spawn(async move {
+            let _sentry = sentry;
+            pending::<()>().await
+        }));
+        yield_now().await;
+        assert!(!*dropped.read().unwrap());
+        drop(task);
+        yield_now().await;
+        assert!(*dropped.read().unwrap());
+    }
+
+    #[tokio::test]
+    async fn drop_while_waiting_for_join() {
+        let dropped = Arc::new(RwLock::new(false));
+        let sentry = Sentry(dropped.clone());
+        let handle = tokio::spawn(async move {
+            ScopedTask::from(tokio::spawn(async move {
+                let _sentry = sentry;
+                pending::<()>().await
+            }))
+            .await
+            .unwrap()
+        });
+        yield_now().await;
+        assert!(!*dropped.read().unwrap());
+        handle.abort();
+        yield_now().await;
+        assert!(*dropped.read().unwrap());
+    }
+
+    #[tokio::test]
+    async fn no_drop_only_join() {
+        assert_eq!(
+            ScopedTask::from(tokio::spawn(async {
+                yield_now().await;
+                5
+            }))
+            .await
+            .unwrap(),
+            5
+        )
+    }
+
+    #[tokio::test]
+    async fn manually_abort_before_drop() {
+        let dropped = Arc::new(RwLock::new(false));
+        let sentry = Sentry(dropped.clone());
+        let task = ScopedTask::from(tokio::spawn(async move {
+            let _sentry = sentry;
+            pending::<()>().await
+        }));
+        yield_now().await;
+        assert!(!*dropped.read().unwrap());
+        task.abort();
+        yield_now().await;
+        assert!(*dropped.read().unwrap());
+    }
+
+    #[tokio::test]
+    async fn manually_abort_then_join() {
+        let dropped = Arc::new(RwLock::new(false));
+        let sentry = Sentry(dropped.clone());
+        let task = ScopedTask::from(tokio::spawn(async move {
+            let _sentry = sentry;
+            pending::<()>().await
+        }));
+        yield_now().await;
+        assert!(!*dropped.read().unwrap());
+        task.abort();
+        yield_now().await;
+        assert!(task.await.is_err());
+    }
+}
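A minimal sketch of the call-site pattern the new wrapper enables (the module path and the spawned body are illustrative, not taken from this PR):

```rust
use easytier::common::scoped_task::ScopedTask;

async fn demo() {
    // Wrapping the JoinHandle ties the spawned task to `task`'s scope:
    // if `task` is dropped, the background future is aborted.
    let task: ScopedTask<()> =
        tokio::spawn(futures_util::future::pending::<()>()).into();

    // Deref to JoinHandle keeps manual control available before any drop.
    task.abort();
    assert!(task.await.is_err()); // an aborted task joins with a JoinError
}
```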
diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs
index caa8b43..e0f69db 100644
--- a/easytier/src/connector/udp_hole_punch.rs
+++ b/easytier/src/connector/udp_hole_punch.rs
@@ -22,7 +22,7 @@ use zerocopy::FromBytes;
 use crate::{
     common::{
         constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
-        stun::StunInfoCollectorTrait, PeerId,
+        scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId,
     },
     defer,
     peers::peer_manager::PeerManager,
@@ -417,12 +417,12 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
         }
 
         // send max k1 packets if we are predicting the dst port
-        let max_k1 = 180;
+        let max_k1 = 60;
         // send max k2 packets if we are sending to random port
         let max_k2 = rand::thread_rng().gen_range(600..800);
 
         // this means the NAT is allocating port in a predictable way
-        if max_port.abs_diff(min_port) <= max_k1 && round <= 6 && punch_predictablely {
+        if max_port.abs_diff(min_port) <= 3 * max_k1 && round <= 6 && punch_predictablely {
             let (min_port, max_port) = {
                 // round begin from 0. if round is even, we guess port in increasing order
                 let port_delta = (max_k1 as u32) / ip_count as u32;
@@ -849,10 +849,10 @@ impl UdpHolePunchConnector {
             return Err(anyhow::anyhow!("failed to get public ips"));
         }
 
-        let mut last_port_idx = 0;
+        let mut last_port_idx = rand::thread_rng().gen_range(0..data.shuffled_port_vec.len());
 
-        for round in 0..30 {
-            let Some(next_last_port_idx) = data
+        for round in 0..5 {
+            let ret = data
                 .peer_mgr
                 .get_peer_rpc_mgr()
                 .do_client_rpc_scoped(
@@ -879,11 +879,20 @@ impl UdpHolePunchConnector {
                     last_port_idx
                 },
             )
-            .await?
-            else {
-                return Err(anyhow::anyhow!("failed to get remote mapped addr"));
+            .await;
+
+            let next_last_port_idx = match ret {
+                Ok(Some(idx)) => idx,
+                err => {
+                    tracing::error!(?err, "failed to get remote mapped addr");
+                    rand::thread_rng().gen_range(0..data.shuffled_port_vec.len())
+                }
             };
+
+            // wait for some time to increase the chance of receiving hole punching packet
+            tokio::time::sleep(Duration::from_secs(2)).await;
+
+            // no matter what the result is, we should check if we received any hole punching packet
             while let Some(socket) = udp_array.try_fetch_punched_socket(tid) {
                 if let Ok(tunnel) = Self::try_connect_with_socket(socket, remote_mapped_addr).await
                 {
@@ -901,8 +910,8 @@ impl UdpHolePunchConnector {
         data: Arc<UdpHolePunchConnectorData>,
         peer_id: PeerId,
     ) -> Result<(), anyhow::Error> {
-        const MAX_BACKOFF_TIME: u64 = 600;
-        let mut backoff_time = vec![15, 15, 30, 30, 60, 120, 300, MAX_BACKOFF_TIME];
+        const MAX_BACKOFF_TIME: u64 = 300;
+        let mut backoff_time = vec![15, 15, 30, 30, 60, 120, 180, MAX_BACKOFF_TIME];
         let my_nat_type = data.my_nat_type();
 
         loop {
@@ -942,7 +951,7 @@ impl UdpHolePunchConnector {
     async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
         type JoinTaskRet = Result<(), anyhow::Error>;
-        type JoinTask = tokio::task::JoinHandle<JoinTaskRet>;
+        type JoinTask = ScopedTask<JoinTaskRet>;
         let punching_task = Arc::new(DashMap::<(PeerId, NatType), JoinTask>::new());
         let mut last_my_nat_type = NatType::Unknown;
@@ -978,23 +987,27 @@ impl UdpHolePunchConnector {
             last_my_nat_type = my_nat_type;
 
             if !peers_to_connect.is_empty() {
-                let my_nat_type = data.my_nat_type();
-                if my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall {
-                    let mut udp_array = data.udp_array.lock().await;
-                    if udp_array.is_none() {
-                        *udp_array = Some(Arc::new(UdpSocketArray::new(
-                            data.udp_array_size.load(Ordering::Relaxed),
-                            data.global_ctx.net_ns.clone(),
-                        )));
-                    }
-                    let udp_array = udp_array.as_ref().unwrap();
-                    udp_array.start().await.unwrap();
-                }
-
                 for item in peers_to_connect {
+                    if punching_task.contains_key(&item) {
+                        continue;
+                    }
+
+                    let my_nat_type = data.my_nat_type();
+                    if my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall {
+                        let mut udp_array = data.udp_array.lock().await;
+                        if udp_array.is_none() {
+                            *udp_array = Some(Arc::new(UdpSocketArray::new(
+                                data.udp_array_size.load(Ordering::Relaxed),
+                                data.global_ctx.net_ns.clone(),
+                            )));
+                        }
+                        let udp_array = udp_array.as_ref().unwrap();
+                        udp_array.start().await.unwrap();
+                    }
+
                     punching_task.insert(
                         item,
-                        tokio::spawn(Self::peer_punching_task(data.clone(), item.0)),
+                        tokio::spawn(Self::peer_punching_task(data.clone(), item.0)).into(),
                     );
                 }
             } else if punching_task.is_empty() {
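The backoff change above shortens the worst-case retry interval from 600s to 300s. A condensed sketch of the pacing as I read it (the schedule-consumption order is an assumption, and `try_one_round` is a stand-in, not the real easytier API):

```rust
use std::time::Duration;

// Retry hole punching with a growing delay between failed rounds,
// capped at MAX_BACKOFF_TIME seconds once the schedule is exhausted.
async fn punch_with_backoff(mut try_one_round: impl FnMut() -> bool) {
    const MAX_BACKOFF_TIME: u64 = 300;
    let mut backoff_time = vec![15u64, 15, 30, 30, 60, 120, 180, MAX_BACKOFF_TIME];
    loop {
        if try_one_round() {
            return; // connected, stop retrying
        }
        // take the next delay, or stay at the cap when the list is empty
        let secs = if backoff_time.is_empty() {
            MAX_BACKOFF_TIME
        } else {
            backoff_time.remove(0)
        };
        tokio::time::sleep(Duration::from_secs(secs)).await;
    }
}
```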
@@ -1173,9 +1186,9 @@ pub mod tests {
         let udp_self = Arc::new(UdpSocket::bind("0.0.0.0:40144").await.unwrap());
         let udp_inc = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap());
-        let udp_inc2 = Arc::new(UdpSocket::bind("0.0.0.0:40400").await.unwrap());
+        let udp_inc2 = Arc::new(UdpSocket::bind("0.0.0.0:40200").await.unwrap());
         let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap());
-        let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40350").await.unwrap());
+        let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40050").await.unwrap());
         let udps = vec![udp_self, udp_inc, udp_inc2, udp_dec, udp_dec2];
 
         let counter = Arc::new(AtomicU32::new(0));
@@ -1186,7 +1199,7 @@ pub mod tests {
             tokio::spawn(async move {
                 let mut buf = [0u8; 1024];
                 let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
-                println!("{:?} {:?}", len, addr);
+                println!("{:?} {:?} {:?}", len, addr, udp.local_addr());
                 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
             });
         }
diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs
index b1f0eea..b879535 100644
--- a/easytier/src/easytier-core.rs
+++ b/easytier/src/easytier-core.rs
@@ -340,7 +340,7 @@ impl Cli {
     }
 
     fn check_tcp_available(port: u16) -> Option<SocketAddr> {
-        let s = format!("127.0.0.1:{}", port).parse::<SocketAddr>().unwrap();
+        let s = format!("0.0.0.0:{}", port).parse::<SocketAddr>().unwrap();
         TcpSocket::new_v4().unwrap().bind(s).map(|_| s).ok()
     }
@@ -353,9 +353,9 @@ impl Cli {
                     return s;
                 }
             }
-            return "127.0.0.1:0".parse().unwrap();
+            return "0.0.0.0:0".parse().unwrap();
         }
-        return format!("127.0.0.1:{}", port).parse().unwrap();
+        return format!("0.0.0.0:{}", port).parse().unwrap();
     }
 
     self.rpc_portal.parse().unwrap()
diff --git a/easytier/src/vpn_portal/wireguard.rs b/easytier/src/vpn_portal/wireguard.rs
index 43e6205..4239fac 100644
--- a/easytier/src/vpn_portal/wireguard.rs
+++ b/easytier/src/vpn_portal/wireguard.rs
@@ -24,7 +24,7 @@ use crate::{
         mpsc::{MpscTunnel, MpscTunnelSender},
         packet_def::{PacketType, ZCPacket, ZCPacketType},
         wireguard::{WgConfig, WgTunnelListener},
-        Tunnel, TunnelError, TunnelListener,
+        Tunnel, TunnelListener,
     },
 };
diff --git a/script/install.sh b/script/install.sh
index 5382310..f312b09 100644
--- a/script/install.sh
+++ b/script/install.sh
@@ -188,7 +188,7 @@ listeners = [
 ]
 exit_nodes = []
 peer = []
-rpc_portal = "127.0.0.1:15888"
+rpc_portal = "0.0.0.0:15888"
 
 [network_identity]
 network_name = "default"
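For reference, the updated port probe from `easytier-core.rs` extracted into a self-contained form (imports added here; the function body is as in the diff):

```rust
use std::net::SocketAddr;
use tokio::net::TcpSocket;

// A port is considered available when a wildcard bind succeeds; the
// bound address is returned so callers can reuse it for the RPC portal.
fn check_tcp_available(port: u16) -> Option<SocketAddr> {
    let s = format!("0.0.0.0:{}", port).parse::<SocketAddr>().unwrap();
    TcpSocket::new_v4().unwrap().bind(s).map(|_| s).ok()
}
```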