diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index f812fb1..fedf404 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -343,6 +343,8 @@ impl StunClientBuilder { pub struct UdpNatTypeDetectResult { source_addr: SocketAddr, stun_resps: Vec, + // if we are easy symmetric nat, we need to test with another port to check inc or dec + extra_bind_test: Option, } impl UdpNatTypeDetectResult { @@ -350,6 +352,7 @@ impl UdpNatTypeDetectResult { Self { source_addr, stun_resps, + extra_bind_test: None, } } @@ -406,7 +409,7 @@ impl UdpNatTypeDetectResult { .filter_map(|x| x.mapped_socket_addr) .collect::>() .len(); - mapped_addr_count < self.stun_server_count() + mapped_addr_count == 1 } pub fn nat_type(&self) -> NatType { @@ -429,7 +432,32 @@ impl UdpNatTypeDetectResult { return NatType::PortRestricted; } } else if !self.stun_resps.is_empty() { - return NatType::Symmetric; + if self.public_ips().len() != 1 + || self.usable_stun_resp_count() <= 1 + || self.max_port() - self.min_port() > 15 + || self.extra_bind_test.is_none() + || self + .extra_bind_test + .as_ref() + .unwrap() + .mapped_socket_addr + .is_none() + { + return NatType::Symmetric; + } else { + let extra_bind_test = self.extra_bind_test.as_ref().unwrap(); + let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port(); + + let max_port_diff = extra_port.saturating_sub(self.max_port()); + let min_port_diff = self.min_port().saturating_sub(extra_port); + if max_port_diff != 0 && max_port_diff < 100 { + return NatType::SymmetricEasyInc; + } else if min_port_diff != 0 && min_port_diff < 100 { + return NatType::SymmetricEasyDec; + } else { + return NatType::Symmetric; + } + } } else { return NatType::Unknown; } @@ -477,6 +505,13 @@ impl UdpNatTypeDetectResult { .max() .unwrap_or(u16::MAX) } + + pub fn usable_stun_resp_count(&self) -> usize { + self.stun_resps + .iter() + .filter(|x| x.mapped_socket_addr.is_some()) + .count() + } } pub struct UdpNatTypeDetector { @@ -492,6 +527,19 @@ impl UdpNatTypeDetector { } } + async fn get_extra_bind_result( + &self, + source_port: u16, + stun_server: SocketAddr, + ) -> Result { + let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?); + let client_builder = StunClientBuilder::new(udp.clone()); + client_builder + .new_stun_client(stun_server) + .bind_request(false, false) + .await + } + pub async fn detect_nat_type(&self, source_port: u16) -> Result { let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?); self.detect_nat_type_with_socket(udp).await @@ -578,13 +626,28 @@ impl StunInfoCollectorTrait for StunInfoCollector { async fn get_udp_port_mapping(&self, local_port: u16) -> Result { self.start_stun_routine(); - let stun_servers = self + let mut stun_servers = self .udp_nat_test_result .read() .unwrap() .clone() .map(|x| x.collect_available_stun_server()) - .ok_or(Error::NotFound)?; + .unwrap_or(vec![]); + + if stun_servers.is_empty() { + let mut host_resolver = + HostResolverIter::new(self.stun_servers.read().unwrap().clone(), 2); + while let Some(addr) = host_resolver.next().await { + stun_servers.push(addr); + if stun_servers.len() >= 2 { + break; + } + } + } + + if stun_servers.is_empty() { + return Err(Error::NotFound); + } let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", local_port)).await?); let mut client_builder = StunClientBuilder::new(udp.clone()); @@ -630,9 +693,9 @@ impl StunInfoCollector { // stun server cross nation may return a external ip address with high latency and loss rate vec![ "stun.miwifi.com", - "stun.cdnbye.com", - "stun.hitv.com", "stun.chat.bilibili.com", + "stun.hitv.com", + "stun.cdnbye.com", "stun.douyucdn.cn:18000", "fwa.lifesizecloud.com", "global.turn.twilio.com", @@ -673,38 +736,41 @@ impl StunInfoCollector { .map(|x| x.to_string()) .collect(); let detector = UdpNatTypeDetector::new(servers, 1); - let ret = detector.detect_nat_type(0).await; + let mut ret = detector.detect_nat_type(0).await; tracing::debug!(?ret, "finish udp nat type detect"); + let mut nat_type = NatType::Unknown; - let sleep_sec = match &ret { - Ok(resp) => { - *udp_nat_test_result.write().unwrap() = Some(resp.clone()); - udp_test_time.store(Local::now()); - nat_type = resp.nat_type(); - if nat_type == NatType::Unknown { - 15 - } else { - 600 - } - } - _ => 15, - }; + if let Ok(resp) = &ret { + tracing::debug!(?resp, "got udp nat type detect result"); + nat_type = resp.nat_type(); + } // if nat type is symmtric, detect with another port to gather more info if nat_type == NatType::Symmetric { - let old_resp = ret.unwrap(); - let old_local_port = old_resp.local_addr().port(); - let new_port = if old_local_port >= 65535 { - old_local_port - 1 - } else { - old_local_port + 1 - }; - let ret = detector.detect_nat_type(new_port).await; - tracing::debug!(?ret, "finish udp nat type detect with another port"); - if let Ok(resp) = ret { - udp_nat_test_result.write().unwrap().as_mut().map(|x| { - x.extend_result(resp); - }); + let old_resp = ret.as_mut().unwrap(); + tracing::debug!(?old_resp, "start get extra bind result"); + let available_stun_servers = old_resp.collect_available_stun_server(); + for server in available_stun_servers.iter() { + let ret = detector + .get_extra_bind_result(0, *server) + .await + .with_context(|| "get extra bind result failed"); + tracing::debug!(?ret, "finish udp nat type detect with another port"); + if let Ok(resp) = ret { + old_resp.extra_bind_test = Some(resp); + break; + } + } + } + + let mut sleep_sec = 10; + if let Ok(resp) = &ret { + udp_test_time.store(Local::now()); + *udp_nat_test_result.write().unwrap() = Some(resp.clone()); + if nat_type != NatType::Unknown + && (nat_type != NatType::Symmetric || resp.extra_bind_test.is_some()) + { + sleep_sec = 600 } } @@ -734,7 +800,7 @@ impl StunInfoCollectorTrait for MockStunInfoCollector { last_update_time: std::time::Instant::now().elapsed().as_secs() as i64, min_port: 100, max_port: 200, - ..Default::default() + public_ip: vec!["127.0.0.1".to_string()], } } diff --git a/easytier/src/connector/direct.rs b/easytier/src/connector/direct.rs index cda5e58..585c1f0 100644 --- a/easytier/src/connector/direct.rs +++ b/easytier/src/connector/direct.rs @@ -425,7 +425,7 @@ impl DirectConnectorManager { ); let ip_list = rpc_stub - .get_ip_list(BaseController {}, GetIpListRequest {}) + .get_ip_list(BaseController::default(), GetIpListRequest {}) .await .with_context(|| format!("get ip list from peer {}", dst_peer_id))?; diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs deleted file mode 100644 index d8851ce..0000000 --- a/easytier/src/connector/udp_hole_punch.rs +++ /dev/null @@ -1,1205 +0,0 @@ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, - }, - time::Duration, - u16, -}; - -use anyhow::Context; -use crossbeam::atomic::AtomicCell; -use dashmap::{DashMap, DashSet}; -use rand::{seq::SliceRandom, Rng}; -use tokio::{ - net::UdpSocket, - sync::{Mutex, Notify}, - task::JoinSet, -}; -use tracing::{instrument, Instrument, Level}; -use zerocopy::FromBytes; - -use crate::{ - common::{ - error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, - scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId, - }, - defer, - peers::peer_manager::PeerManager, - proto::{ - common::NatType, - peer_rpc::{ - TryPunchHoleRequest, TryPunchHoleResponse, TryPunchSymmetricRequest, - TryPunchSymmetricResponse, UdpHolePunchRpc, UdpHolePunchRpcClientFactory, - UdpHolePunchRpcServer, - }, - rpc_types::{self, controller::BaseController}, - }, - tunnel::{ - common::setup_sokcet2, - packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE}, - udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, - Tunnel, TunnelConnCounter, TunnelListener, - }, -}; - -use super::direct::PeerManagerForDirectConnector; - -const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16; - -fn generate_shuffled_port_vec() -> Vec { - let mut rng = rand::thread_rng(); - let mut port_vec: Vec = (1..=65535).collect(); - port_vec.shuffle(&mut rng); - port_vec -} - -// used for symmetric hole punching, binding to multiple ports to increase the chance of success -struct UdpSocketArray { - sockets: Arc>>, - max_socket_count: usize, - net_ns: NetNS, - tasks: Arc>>, - - intreast_tids: Arc>, - tid_to_socket: Arc>>>, -} - -impl UdpSocketArray { - pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self { - let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); - join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned()); - - Self { - sockets: Arc::new(DashMap::new()), - max_socket_count, - net_ns, - tasks, - - intreast_tids: Arc::new(DashSet::new()), - tid_to_socket: Arc::new(DashMap::new()), - } - } - - pub fn started(&self) -> bool { - !self.sockets.is_empty() - } - - async fn add_new_socket(&self) -> Result<(), anyhow::Error> { - let socket = { - let _g = self.net_ns.guard(); - Arc::new(UdpSocket::bind("0.0.0.0:0").await?) - }; - let local_addr = socket.local_addr()?; - self.sockets.insert(local_addr, socket.clone()); - - let intreast_tids = self.intreast_tids.clone(); - let tid_to_socket = self.tid_to_socket.clone(); - self.tasks.lock().unwrap().spawn( - async move { - let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize]; - tracing::trace!(?local_addr, "udp socket added"); - loop { - let Ok((len, addr)) = socket.recv_from(&mut buf).await else { - break; - }; - - tracing::debug!(?len, ?addr, "got raw packet"); - - if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize { - continue; - } - - let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else { - continue; - }; - - tracing::debug!(?p, ?addr, "got udp hole punch packet"); - - if p.msg_type != UdpPacketType::HolePunch as u8 - || p.len.get() != HOLE_PUNCH_PACKET_BODY_LEN - { - continue; - } - - let tid = p.conn_id.get(); - if intreast_tids.contains(&tid) { - tracing::info!(?addr, "got hole punching packet with intreast tid"); - tid_to_socket - .entry(tid) - .or_insert_with(Vec::new) - .push(socket); - break; - } - } - tracing::debug!(?local_addr, "udp socket recv loop end"); - } - .instrument(tracing::info_span!("udp array socket recv loop")), - ); - Ok(()) - } - - #[instrument(err)] - pub async fn start(&self) -> Result<(), anyhow::Error> { - if self.started() { - return Ok(()); - } - - tracing::info!("starting udp socket array"); - - while self.sockets.len() < self.max_socket_count { - self.add_new_socket().await?; - } - - Ok(()) - } - - #[instrument(err)] - pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> { - tracing::info!(?addr, "sending hole punching packet"); - - for socket in self.sockets.iter() { - let socket = socket.value(); - socket.send_to(data, addr).await?; - } - - Ok(()) - } - - #[instrument(ret(level = Level::DEBUG))] - pub fn try_fetch_punched_socket(&self, tid: u32) -> Option> { - tracing::debug!(?tid, "try fetch punched socket"); - self.tid_to_socket.get_mut(&tid)?.value_mut().pop() - } - - pub fn add_intreast_tid(&self, tid: u32) { - self.intreast_tids.insert(tid); - } - - pub fn remove_intreast_tid(&self, tid: u32) { - self.intreast_tids.remove(&tid); - self.tid_to_socket.remove(&tid); - } -} - -impl std::fmt::Debug for UdpSocketArray { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("UdpSocketArray") - .field("sockets", &self.sockets.len()) - .field("max_socket_count", &self.max_socket_count) - .field("started", &self.started()) - .field("intreast_tids", &self.intreast_tids.len()) - .field("tid_to_socket", &self.tid_to_socket.len()) - .finish() - } -} - -#[derive(Debug)] -struct UdpHolePunchListener { - socket: Arc, - tasks: JoinSet<()>, - running: Arc>, - mapped_addr: SocketAddr, - conn_counter: Arc>, - - listen_time: std::time::Instant, - last_select_time: AtomicCell, - last_active_time: Arc>, -} - -impl UdpHolePunchListener { - async fn get_avail_port() -> Result { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - Ok(socket.local_addr()?.port()) - } - - #[instrument(err)] - pub async fn new(peer_mgr: Arc) -> Result { - let port = Self::get_avail_port().await?; - let listen_url = format!("udp://0.0.0.0:{}", port); - - let gctx = peer_mgr.get_global_ctx(); - let stun_info_collect = gctx.get_stun_info_collector(); - let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?; - - let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap()); - - { - let _g = peer_mgr.get_global_ctx().net_ns.guard(); - listener.listen().await?; - } - let socket = listener.get_socket().unwrap(); - - let running = Arc::new(AtomicCell::new(true)); - let running_clone = running.clone(); - - let conn_counter = listener.get_conn_counter(); - let mut tasks = JoinSet::new(); - - tasks.spawn(async move { - while let Ok(conn) = listener.accept().await { - tracing::warn!(?conn, "udp hole punching listener got peer connection"); - let peer_mgr = peer_mgr.clone(); - tokio::spawn(async move { - if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { - tracing::error!( - ?e, - "failed to add tunnel as server in hole punch listener" - ); - } - }); - } - - running_clone.store(false); - }); - - let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now())); - let conn_counter_clone = conn_counter.clone(); - let last_active_time_clone = last_active_time.clone(); - tasks.spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - if conn_counter_clone.get() != 0 { - last_active_time_clone.store(std::time::Instant::now()); - } - } - }); - - tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started"); - - Ok(Self { - tasks, - socket, - running, - mapped_addr, - conn_counter, - - listen_time: std::time::Instant::now(), - last_select_time: AtomicCell::new(std::time::Instant::now()), - last_active_time, - }) - } - - pub async fn get_socket(&self) -> Arc { - self.last_select_time.store(std::time::Instant::now()); - self.socket.clone() - } -} - -struct UdpHolePunchConnectorData { - global_ctx: ArcGlobalCtx, - peer_mgr: Arc, - listeners: Arc>>, - shuffled_port_vec: Arc>, - - udp_array: Arc>>>, - try_direct_connect: AtomicBool, - punch_predicablely: AtomicBool, - punch_randomly: AtomicBool, - udp_array_size: AtomicUsize, -} - -impl std::fmt::Debug for UdpHolePunchConnectorData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // print peer id listener count - let peer_id = self.peer_mgr.my_peer_id(); - f.debug_struct("UdpHolePunchConnectorData") - .field("peer_id", &peer_id) - .finish() - } -} - -impl UdpHolePunchConnectorData { - fn my_nat_type(&self) -> NatType { - let stun_info = self.global_ctx.get_stun_info_collector().get_stun_info(); - NatType::try_from(stun_info.udp_nat_type).unwrap() - } -} - -#[derive(Clone)] -struct UdpHolePunchRpcService { - data: Arc, - - tasks: Arc>>, -} - -#[async_trait::async_trait] -impl UdpHolePunchRpc for UdpHolePunchRpcService { - type Controller = BaseController; - - #[tracing::instrument(skip(self))] - async fn try_punch_hole( - &self, - _: BaseController, - request: TryPunchHoleRequest, - ) -> Result { - let local_mapped_addr = request.local_mapped_addr.ok_or(anyhow::anyhow!( - "try_punch_hole request missing local_mapped_addr" - ))?; - let local_mapped_addr = std::net::SocketAddr::from(local_mapped_addr); - // local mapped addr will be unspecified if peer is symmetric - let peer_is_symmetric = local_mapped_addr.ip().is_unspecified(); - let (socket, mapped_addr) = - self.select_listener(peer_is_symmetric) - .await - .ok_or(anyhow::anyhow!( - "failed to select listener for hole punching" - ))?; - tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching"); - - if !peer_is_symmetric { - let my_udp_nat_type = self - .data - .global_ctx - .get_stun_info_collector() - .get_stun_info() - .udp_nat_type; - - // if we are cone, we need to send hole punching resp to client - if my_udp_nat_type == NatType::PortRestricted as i32 - || my_udp_nat_type == NatType::Restricted as i32 - || my_udp_nat_type == NatType::FullCone as i32 - { - let notifier = Arc::new(Notify::new()); - - let n = notifier.clone(); - // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second - self.tasks.lock().unwrap().spawn(async move { - for i in 0..10 { - tracing::info!(?local_mapped_addr, "sending hole punching packet"); - - let udp_packet = new_hole_punch_packet(100, HOLE_PUNCH_PACKET_BODY_LEN); - let _ = socket - .send_to(&udp_packet.into_bytes(), local_mapped_addr) - .await; - let sleep_ms = if i < 4 { 10 } else { 500 }; - tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await; - if i == 3 { - n.notify_one(); - } - } - }); - - notifier.notified().await; - } - } - - Ok(TryPunchHoleResponse { - remote_mapped_addr: Some(mapped_addr.into()), - }) - } - - #[instrument(skip(self))] - async fn try_punch_symmetric( - &self, - _: BaseController, - request: TryPunchSymmetricRequest, - ) -> Result { - let listener_addr = request.listener_addr.ok_or(anyhow::anyhow!( - "try_punch_symmetric request missing listener_addr" - ))?; - let listener_addr = std::net::SocketAddr::from(listener_addr); - let port = request.port as u16; - let public_ips = request - .public_ips - .into_iter() - .map(|ip| std::net::Ipv4Addr::from(ip)) - .collect::>(); - let mut min_port = request.min_port as u16; - let mut max_port = request.max_port as u16; - let transaction_id = request.transaction_id; - let round = request.round; - let last_port_index = request.last_port_index as usize; - - tracing::info!("try_punch_symmetric start"); - - let punch_predictablely = self.data.punch_predicablely.load(Ordering::Relaxed); - let punch_randomly = self.data.punch_randomly.load(Ordering::Relaxed); - let total_port_count = self.data.shuffled_port_vec.len(); - let listener = self - .find_listener(&listener_addr) - .await - .ok_or(anyhow::anyhow!( - "try_punch_symmetric failed to find listener" - ))?; - let ip_count = public_ips.len(); - if ip_count == 0 { - tracing::warn!("try_punch_symmetric got zero len public ip"); - return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into()); - } - - min_port = std::cmp::max(1, min_port); - if max_port == 0 { - max_port = u16::MAX; - } - if max_port < min_port { - std::mem::swap(&mut min_port, &mut max_port); - } - - // send max k1 packets if we are predicting the dst port - let max_k1 = 180; - // send max k2 packets if we are sending to random port - let max_k2 = rand::thread_rng().gen_range(600..800); - - // this means the NAT is allocating port in a predictable way - if max_port.abs_diff(min_port) <= max_k1 && round <= 6 && punch_predictablely { - let (min_port, max_port) = { - // round begin from 0. if round is even, we guess port in increasing order - let port_delta = (max_k1 as u32) / ip_count as u32; - let port_diff_for_min = std::cmp::min((round / 2) * port_delta, u16::MAX as u32); - if round % 2 == 0 { - let lower = std::cmp::max(1, port.saturating_add(port_diff_for_min as u16)); - let upper = lower.saturating_add(port_delta as u16); - (lower, upper) - } else { - let upper = std::cmp::max(1, port.saturating_sub(port_diff_for_min as u16)); - let lower = std::cmp::max(1, upper.saturating_sub(port_delta as u16)); - (lower, upper) - } - }; - let mut ports = (min_port..=max_port).collect::>(); - ports.push(max_port); - ports.shuffle(&mut rand::thread_rng()); - self.send_symmetric_hole_punch_packet( - listener.clone(), - transaction_id, - &public_ips, - &ports, - ) - .await - .with_context(|| "failed to send symmetric hole punch packet predict")?; - } - - if punch_randomly { - let start = last_port_index % total_port_count; - let diff = std::cmp::max(10, max_k2 / ip_count); - let end = std::cmp::min(start + diff, self.data.shuffled_port_vec.len()); - self.send_symmetric_hole_punch_packet( - listener.clone(), - transaction_id, - &public_ips, - &self.data.shuffled_port_vec[start..end], - ) - .await - .with_context(|| "failed to send symmetric hole punch packet randomly")?; - - return if end >= self.data.shuffled_port_vec.len() { - Ok(TryPunchSymmetricResponse { last_port_index: 1 }) - } else { - Ok(TryPunchSymmetricResponse { - last_port_index: end as u32, - }) - }; - } - - return Ok(TryPunchSymmetricResponse { last_port_index: 1 }); - } -} - -impl UdpHolePunchRpcService { - pub fn new(data: Arc) -> Self { - let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); - join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned()); - Self { data, tasks } - } - - async fn find_listener(&self, addr: &SocketAddr) -> Option> { - let all_listener_sockets = self.data.listeners.lock().await; - - let listener = all_listener_sockets - .iter() - .find(|listener| listener.mapped_addr == *addr && listener.running.load())?; - - Some(listener.get_socket().await) - } - - async fn select_listener( - &self, - use_new_listener: bool, - ) -> Option<(Arc, SocketAddr)> { - let all_listener_sockets = &self.data.listeners; - - // remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds - all_listener_sockets.lock().await.retain(|listener| { - listener.last_active_time.load().elapsed().as_secs() < 40 - || listener.last_select_time.load().elapsed().as_secs() < 30 - }); - - let mut use_last = false; - if all_listener_sockets.lock().await.len() < 4 || use_new_listener { - tracing::warn!("creating new udp hole punching listener"); - all_listener_sockets.lock().await.push( - UdpHolePunchListener::new(self.data.peer_mgr.clone()) - .await - .ok()?, - ); - use_last = true; - } - - let locked = all_listener_sockets.lock().await; - - let listener = if use_last { - locked.last()? - } else { - // use the listener that is active most recently - locked - .iter() - .max_by_key(|listener| listener.last_active_time.load())? - }; - - Some((listener.get_socket().await, listener.mapped_addr)) - } - - #[tracing::instrument(err, ret(level=Level::DEBUG), skip(self, ports))] - async fn send_symmetric_hole_punch_packet( - &self, - udp: Arc, - transaction_id: u32, - public_ips: &Vec, - ports: &[u16], - ) -> Result<(), Error> { - tracing::debug!( - ?public_ips, - "sending symmetric hole punching packet, ports len: {}", - ports.len(), - ); - for port in ports { - for pub_ip in public_ips { - let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, *port)); - let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN); - udp.send_to(&packet.into_bytes(), addr).await?; - tokio::time::sleep(Duration::from_millis(2)).await; - } - } - Ok(()) - } -} - -pub struct UdpHolePunchConnector { - data: Arc, - tasks: JoinSet<()>, -} - -// Currently support: -// Symmetric -> Full Cone -// Any Type of Full Cone -> Any Type of Full Cone - -// if same level of full cone, node with smaller peer_id will be the initiator -// if different level of full cone, node with more strict level will be the initiator - -impl UdpHolePunchConnector { - pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc) -> Self { - Self { - data: Arc::new(UdpHolePunchConnectorData { - global_ctx, - peer_mgr, - listeners: Arc::new(Mutex::new(Vec::new())), - shuffled_port_vec: Arc::new(generate_shuffled_port_vec()), - udp_array: Arc::new(Mutex::new(None)), - try_direct_connect: AtomicBool::new(true), - punch_predicablely: AtomicBool::new(true), - punch_randomly: AtomicBool::new(true), - udp_array_size: AtomicUsize::new(80), - }), - tasks: JoinSet::new(), - } - } - - pub async fn run_as_client(&mut self) -> Result<(), Error> { - let data = self.data.clone(); - self.tasks.spawn(async move { - Self::main_loop(data).await; - }); - - Ok(()) - } - - pub async fn run_as_server(&mut self) -> Result<(), Error> { - self.data - .peer_mgr - .get_peer_rpc_mgr() - .rpc_server() - .registry() - .register( - UdpHolePunchRpcServer::new(UdpHolePunchRpcService::new(self.data.clone())), - &self.data.global_ctx.get_network_name(), - ); - - Ok(()) - } - - pub async fn run(&mut self) -> Result<(), Error> { - if self.data.global_ctx.get_flags().disable_p2p { - return Ok(()); - } - if self.data.global_ctx.get_flags().disable_udp_hole_punching { - return Ok(()); - } - - self.run_as_client().await?; - self.run_as_server().await?; - - Ok(()) - } - - async fn collect_peer_to_connect( - data: Arc, - ) -> Vec<(PeerId, NatType)> { - let mut peers_to_connect = Vec::new(); - - // do not do anything if: - // 1. our stun test has not finished - // 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us - let my_nat_type = data.my_nat_type(); - if my_nat_type == NatType::Unknown - || my_nat_type == NatType::OpenInternet - || my_nat_type == NatType::NoPat - { - return peers_to_connect; - } - - // collect peer list from peer manager and do some filter: - // 1. peers without direct conns; - // 2. peers is full cone (any restricted type); - for route in data.peer_mgr.list_routes().await.iter() { - let Some(peer_stun_info) = route.stun_info.as_ref() else { - continue; - }; - let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else { - continue; - }; - - let peer_id: PeerId = route.peer_id; - let conns = data.peer_mgr.list_peer_conns(peer_id).await; - if conns.is_some() && conns.unwrap().len() > 0 { - continue; - } - - // if peer is symmetric ignore it because we cannot connect to it - // if peer is open internet or no pat, direct connector will connecto to it - if peer_nat_type == NatType::Unknown - || peer_nat_type == NatType::OpenInternet - || peer_nat_type == NatType::NoPat - || peer_nat_type == NatType::Symmetric - || peer_nat_type == NatType::SymUdpFirewall - { - continue; - } - - // if we are symmetric, we can only connect to cone peer - if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall) - && (peer_nat_type == NatType::Symmetric || peer_nat_type == NatType::SymUdpFirewall) - { - continue; - } - - // if we have smae level of full cone, node with smaller peer_id will be the initiator - if my_nat_type == peer_nat_type { - if data.peer_mgr.my_peer_id() > peer_id { - continue; - } - } else { - // if we have different level of full cone - // we will be the initiator if we have more strict level - if my_nat_type < peer_nat_type { - continue; - } - } - - tracing::info!( - ?peer_id, - ?peer_nat_type, - ?my_nat_type, - ?data.global_ctx.id, - "found peer to do hole punching" - ); - - peers_to_connect.push((peer_id, peer_nat_type)); - } - - peers_to_connect - } - - async fn try_connect_with_socket( - socket: Arc, - remote_mapped_addr: SocketAddr, - ) -> Result, Error> { - let connector = UdpTunnelConnector::new( - format!( - "udp://{}:{}", - remote_mapped_addr.ip(), - remote_mapped_addr.port() - ) - .to_string() - .parse() - .unwrap(), - ); - connector - .try_connect_with_socket(socket, remote_mapped_addr) - .await - .map_err(|e| Error::from(e)) - } - - #[tracing::instrument(err)] - async fn do_hole_punching_cone( - data: Arc, - dst_peer_id: PeerId, - ) -> Result, anyhow::Error> { - tracing::info!(?dst_peer_id, "start hole punching"); - // client: choose a local udp port, and get the pubic mapped port from stun server - let socket = { - let _g = data.global_ctx.net_ns.guard(); - UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")? - }; - let local_socket_addr = socket.local_addr()?; - let local_port = socket.local_addr()?.port(); - drop(socket); // drop the socket to release the port - - let local_mapped_addr = data - .global_ctx - .get_stun_info_collector() - .get_udp_port_mapping(local_port) - .await - .with_context(|| "failed to get udp port mapping")?; - - // client -> server: tell server the mapped port, server will return the mapped address of listening port. - let rpc_stub = data - .peer_mgr - .get_peer_rpc_mgr() - .rpc_client() - .scoped_client::>( - data.peer_mgr.my_peer_id(), - dst_peer_id, - data.global_ctx.get_network_name(), - ); - - let remote_mapped_addr = rpc_stub - .try_punch_hole( - BaseController {}, - TryPunchHoleRequest { - local_mapped_addr: Some(local_mapped_addr.into()), - }, - ) - .await? - .remote_mapped_addr - .ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?; - - // server: will send some punching resps, total 10 packets. - // client: use the socket to create UdpTunnel with UdpTunnelConnector - // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote. - let _g = data.global_ctx.net_ns.guard(); - let socket2_socket = socket2::Socket::new( - socket2::Domain::for_address(local_socket_addr), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - )?; - setup_sokcet2(&socket2_socket, &local_socket_addr)?; - let socket = Arc::new(UdpSocket::from_std(socket2_socket.into())?); - - Ok( - Self::try_connect_with_socket(socket, remote_mapped_addr.into()) - .await - .with_context(|| "UdpTunnelConnector failed to connect remote")?, - ) - } - - #[tracing::instrument(err(level = Level::ERROR))] - async fn do_hole_punching_symmetric( - data: Arc, - dst_peer_id: PeerId, - ) -> Result, anyhow::Error> { - let Some(udp_array) = data.udp_array.lock().await.clone() else { - return Err(anyhow::anyhow!("udp array not started")); - }; - - let rpc_stub = data - .peer_mgr - .get_peer_rpc_mgr() - .rpc_client() - .scoped_client::>( - data.peer_mgr.my_peer_id(), - dst_peer_id, - data.global_ctx.get_network_name(), - ); - - let local_mapped_addr: SocketAddr = "0.0.0.0:0".parse().unwrap(); - let remote_mapped_addr = rpc_stub - .try_punch_hole( - BaseController {}, - TryPunchHoleRequest { - local_mapped_addr: Some(local_mapped_addr.into()), - }, - ) - .await? - .remote_mapped_addr - .ok_or(anyhow::anyhow!("failed to get remote mapped addr"))? - .into(); - - // try direct connect first - if data.try_direct_connect.load(Ordering::Relaxed) { - if let Ok(tunnel) = Self::try_connect_with_socket( - Arc::new(UdpSocket::bind("0.0.0.0:0").await?), - remote_mapped_addr, - ) - .await - { - return Ok(tunnel); - } - } - - let tid = rand::thread_rng().gen(); - let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(); - udp_array.add_intreast_tid(tid); - defer! { udp_array.remove_intreast_tid(tid);} - udp_array.send_with_all(&packet, remote_mapped_addr).await?; - - // get latest port mapping - let local_mapped_addr = data - .global_ctx - .get_stun_info_collector() - .get_udp_port_mapping(0) - .await?; - let port = local_mapped_addr.port(); - let IpAddr::V4(ipv4) = local_mapped_addr.ip() else { - return Err(anyhow::anyhow!("failed to get local mapped addr")); - }; - let stun_info = data.global_ctx.get_stun_info_collector().get_stun_info(); - let mut public_ips: Vec = stun_info - .public_ip - .iter() - .map(|x| x.parse().unwrap()) - .collect(); - if !public_ips.contains(&ipv4) { - public_ips.push(ipv4); - } - if public_ips.is_empty() { - return Err(anyhow::anyhow!("failed to get public ips")); - } - - let mut last_port_idx = rand::thread_rng().gen_range(0..data.shuffled_port_vec.len()); - - for round in 0..30 { - let ret = rpc_stub - .try_punch_symmetric( - BaseController {}, - TryPunchSymmetricRequest { - listener_addr: Some(remote_mapped_addr.into()), - port: port as u32, - public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(), - min_port: stun_info.min_port as u32, - max_port: stun_info.max_port as u32, - transaction_id: tid, - round, - last_port_index: last_port_idx as u32, - }, - ) - .await; - tracing::info!(?ret, "punch symmetric return"); - - let next_last_port_idx = match ret { - Ok(s) => s.last_port_index as usize, - Err(err) => { - tracing::error!(?err, "failed to get remote mapped addr"); - rand::thread_rng().gen_range(0..data.shuffled_port_vec.len()) - } - }; - - // wait for some time to increase the chance of receiving hole punching packet - tokio::time::sleep(Duration::from_secs(2)).await; - - // no matter what the result is, we should check if we received any hole punching packet - while let Some(socket) = udp_array.try_fetch_punched_socket(tid) { - if let Ok(tunnel) = Self::try_connect_with_socket(socket, remote_mapped_addr).await - { - return Ok(tunnel); - } - } - - last_port_idx = next_last_port_idx; - } - - return Err(anyhow::anyhow!("udp array not started")); - } - - async fn peer_punching_task( - data: Arc, - peer_id: PeerId, - ) -> Result<(), anyhow::Error> { - const MAX_BACKOFF_TIME: u64 = 300; - let mut backoff_time = vec![15, 15, 30, 30, 60, 120, 180, MAX_BACKOFF_TIME]; - let my_nat_type = data.my_nat_type(); - - loop { - let ret = if my_nat_type == NatType::FullCone - || my_nat_type == NatType::Restricted - || my_nat_type == NatType::PortRestricted - { - Self::do_hole_punching_cone(data.clone(), peer_id).await - } else { - Self::do_hole_punching_symmetric(data.clone(), peer_id).await - }; - - match ret { - Err(_) => { - tokio::time::sleep(Duration::from_secs( - backoff_time.pop().unwrap_or(MAX_BACKOFF_TIME), - )) - .await; - continue; - } - - Ok(tunnel) => { - let _ = data - .peer_mgr - .add_client_tunnel(tunnel) - .await - .with_context(|| { - "failed to add tunnel as client in hole punch connector" - })?; - break; - } - } - } - - Ok(()) - } - - async fn main_loop(data: Arc) { - type JoinTaskRet = Result<(), anyhow::Error>; - type JoinTask = ScopedTask; - let punching_task = Arc::new(DashMap::<(PeerId, NatType), JoinTask>::new()); - let mut last_my_nat_type = NatType::Unknown; - - loop { - let my_nat_type = data.my_nat_type(); - let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await; - - // remove task not in peers_to_connect - let mut to_remove = vec![]; - for item in punching_task.iter() { - if !peers_to_connect.contains(item.key()) - || item.value().is_finished() - || my_nat_type != last_my_nat_type - { - to_remove.push(item.key().clone()); - } - } - for key in to_remove { - if let Some((_, task)) = punching_task.remove(&key) { - task.abort(); - match task.await { - Ok(Ok(_)) => {} - Ok(Err(task_ret)) => { - tracing::error!(?task_ret, "hole punching task failed"); - } - Err(e) => { - tracing::error!(?e, "hole punching task aborted"); - } - } - } - } - - last_my_nat_type = my_nat_type; - - if !peers_to_connect.is_empty() { - for item in peers_to_connect { - if punching_task.contains_key(&item) { - continue; - } - - let my_nat_type = data.my_nat_type(); - if my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall { - let mut udp_array = data.udp_array.lock().await; - if udp_array.is_none() { - *udp_array = Some(Arc::new(UdpSocketArray::new( - data.udp_array_size.load(Ordering::Relaxed), - data.global_ctx.net_ns.clone(), - ))); - } - let udp_array = udp_array.as_ref().unwrap(); - udp_array.start().await.unwrap(); - } - - punching_task.insert( - item, - tokio::spawn(Self::peer_punching_task(data.clone(), item.0)).into(), - ); - } - } else if punching_task.is_empty() { - data.udp_array.lock().await.take(); - } - - tokio::time::sleep(std::time::Duration::from_secs(10)).await; - } - } -} - -#[cfg(test)] -pub mod tests { - use std::sync::atomic::AtomicU32; - use std::sync::Arc; - use std::time::Duration; - - use tokio::net::UdpSocket; - - use crate::common::stun::MockStunInfoCollector; - use crate::proto::common::NatType; - use crate::tunnel::common::tests::wait_for_condition; - - use crate::{ - connector::udp_hole_punch::UdpHolePunchConnector, - peers::{ - peer_manager::PeerManager, - tests::{ - connect_peer_manager, create_mock_peer_manager, wait_route_appear, - wait_route_appear_with_cost, - }, - }, - }; - - pub fn replace_stun_info_collector(peer_mgr: Arc, udp_nat_type: NatType) { - let collector = Box::new(MockStunInfoCollector { udp_nat_type }); - peer_mgr - .get_global_ctx() - .replace_stun_info_collector(collector); - } - - pub async fn create_mock_peer_manager_with_mock_stun( - udp_nat_type: NatType, - ) -> Arc { - let p_a = create_mock_peer_manager().await; - replace_stun_info_collector(p_a.clone(), udp_nat_type); - p_a - } - - #[tokio::test] - async fn hole_punching_cone() { - let p_a = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; - let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - let p_c = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; - connect_peer_manager(p_a.clone(), p_b.clone()).await; - connect_peer_manager(p_b.clone(), p_c.clone()).await; - - wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); - - println!("{:?}", p_a.list_routes().await); - - let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); - let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); - - hole_punching_a.run().await.unwrap(); - hole_punching_c.run().await.unwrap(); - - wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) - .await - .unwrap(); - println!("{:?}", p_a.list_routes().await); - } - - #[tokio::test] - async fn hole_punching_symmetric_only_random() { - let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; - let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - connect_peer_manager(p_a.clone(), p_b.clone()).await; - connect_peer_manager(p_b.clone(), p_c.clone()).await; - wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); - - let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); - let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); - - hole_punching_a - .data - .try_direct_connect - .store(false, std::sync::atomic::Ordering::Relaxed); - - hole_punching_c - .data - .punch_predicablely - .store(false, std::sync::atomic::Ordering::Relaxed); - - hole_punching_a.run().await.unwrap(); - hole_punching_c.run().await.unwrap(); - - wait_for_condition( - || async { hole_punching_a.data.udp_array.lock().await.is_some() }, - Duration::from_secs(5), - ) - .await; - - wait_for_condition( - || async { - wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) - .await - .is_ok() - }, - Duration::from_secs(30), - ) - .await; - println!("{:?}", p_a.list_routes().await); - - wait_for_condition( - || async { hole_punching_a.data.udp_array.lock().await.is_none() }, - Duration::from_secs(20), - ) - .await; - } - - #[tokio::test] - async fn hole_punching_symmetric_only_predict() { - let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; - let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; - connect_peer_manager(p_a.clone(), p_b.clone()).await; - connect_peer_manager(p_b.clone(), p_c.clone()).await; - wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); - - let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone()); - let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone()); - - hole_punching_a - .data - .try_direct_connect - .store(false, std::sync::atomic::Ordering::Relaxed); - hole_punching_a - .data - .udp_array_size - .store(0, std::sync::atomic::Ordering::Relaxed); - - hole_punching_c - .data - .punch_randomly - .store(false, std::sync::atomic::Ordering::Relaxed); - - hole_punching_a.run().await.unwrap(); - hole_punching_c.run().await.unwrap(); - - let udp_self = Arc::new(UdpSocket::bind("0.0.0.0:40144").await.unwrap()); - let udp_inc = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap()); - let udp_inc2 = Arc::new(UdpSocket::bind("0.0.0.0:40200").await.unwrap()); - let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap()); - let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40050").await.unwrap()); - let udps = vec![udp_self, udp_inc, udp_inc2, udp_dec, udp_dec2]; - - let counter = Arc::new(AtomicU32::new(0)); - - // all these sockets should receive hole punching packet - for udp in udps.iter().map(Arc::clone) { - let counter = counter.clone(); - tokio::spawn(async move { - let mut buf = [0u8; 1024]; - let (len, addr) = udp.recv_from(&mut buf).await.unwrap(); - println!("{:?} {:?} {:?}", len, addr, udp.local_addr()); - counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }); - } - - let udp_len = udps.len(); - wait_for_condition( - || async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 }, - Duration::from_secs(30), - ) - .await; - } -} diff --git a/easytier/src/connector/udp_hole_punch/both_easy_sym.rs b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs new file mode 100644 index 0000000..4e66c17 --- /dev/null +++ b/easytier/src/connector/udp_hole_punch/both_easy_sym.rs @@ -0,0 +1,396 @@ +use std::{ + net::{IpAddr, SocketAddr, SocketAddrV4}, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Context; +use tokio::sync::Mutex; + +use crate::{ + common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId}, + connector::udp_hole_punch::common::{ + try_connect_with_socket, UdpHolePunchListener, HOLE_PUNCH_PACKET_BODY_LEN, + }, + peers::peer_manager::PeerManager, + proto::{ + peer_rpc::{ + SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse, + UdpHolePunchRpcClientFactory, + }, + rpc_types::{self, controller::BaseController}, + }, + tunnel::{udp::new_hole_punch_packet, Tunnel}, +}; + +use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray}; + +const UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM: usize = 25; +const DST_PORT_OFFSET: u16 = 20; +const REMOTE_WAIT_TIME_MS: u64 = 5000; + +pub(crate) struct PunchBothEasySymHoleServer { + common: Arc, + task: Mutex>>, +} + +impl PunchBothEasySymHoleServer { + pub(crate) fn new(common: Arc) -> Self { + Self { + common, + task: Mutex::new(None), + } + } + + // hard sym means public port is random and cannot be predicted + #[tracing::instrument(skip(self), ret, err)] + pub(crate) async fn send_punch_packet_both_easy_sym( + &self, + request: SendPunchPacketBothEasySymRequest, + ) -> Result { + tracing::info!("send_punch_packet_both_easy_sym start"); + let busy_resp = Ok(SendPunchPacketBothEasySymResponse { + is_busy: true, + ..Default::default() + }); + let Ok(mut locked_task) = self.task.try_lock() else { + return busy_resp; + }; + if locked_task.is_some() && !locked_task.as_ref().unwrap().is_finished() { + return busy_resp; + } + + let global_ctx = self.common.get_global_ctx(); + let cur_mapped_addr = global_ctx + .get_stun_info_collector() + .get_udp_port_mapping(0) + .await + .with_context(|| "failed to get udp port mapping")?; + + tracing::info!("send_punch_packet_hard_sym start"); + let socket_count = request.udp_socket_count as usize; + let public_ips = request + .public_ip + .ok_or(anyhow::anyhow!("public_ip is required"))?; + let transaction_id = request.transaction_id; + + let udp_array = + UdpSocketArray::new(socket_count, self.common.get_global_ctx().net_ns.clone()); + udp_array.start().await?; + udp_array.add_intreast_tid(transaction_id); + let peer_mgr = self.common.get_peer_mgr(); + + let punch_packet = + new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(); + let mut punched = vec![]; + let common = self.common.clone(); + + let task = tokio::spawn(async move { + let mut listeners = Vec::new(); + let start_time = Instant::now(); + let wait_time_ms = request.wait_time_ms.min(8000); + while start_time.elapsed() < Duration::from_millis(wait_time_ms as u64) { + if let Err(e) = udp_array + .send_with_all( + &punch_packet, + SocketAddr::V4(SocketAddrV4::new( + public_ips.into(), + request.dst_port_num as u16, + )), + ) + .await + { + tracing::error!(?e, "failed to send hole punch packet"); + break; + } + + tokio::time::sleep(Duration::from_millis(100)).await; + + if let Some(s) = udp_array.try_fetch_punched_socket(transaction_id) { + tracing::info!(?s, ?transaction_id, "got punched socket in both easy sym"); + assert!(Arc::strong_count(&s.socket) == 1); + let Some(port) = s.socket.local_addr().ok().map(|addr| addr.port()) else { + tracing::warn!("failed to get local addr from punched socket"); + continue; + }; + let remote_addr = s.remote_addr; + drop(s); + + let listener = + match UdpHolePunchListener::new_ext(peer_mgr.clone(), false, Some(port)) + .await + { + Ok(l) => l, + Err(e) => { + tracing::warn!(?e, "failed to create listener"); + continue; + } + }; + punched.push((listener.get_socket().await, remote_addr)); + listeners.push(listener); + } + + // if any listener is punched, we can break the loop + for l in &listeners { + if l.get_conn_count().await > 0 { + tracing::info!(?l, "got punched listener"); + break; + } + } + + if !punched.is_empty() { + tracing::debug!(?punched, "got punched socket and keep sending punch packet"); + } + + for p in &punched { + let (socket, remote_addr) = p; + let send_remote_ret = socket.send_to(&punch_packet, remote_addr).await; + tracing::debug!( + ?send_remote_ret, + ?socket, + "send hole punch packet to punched remote" + ); + } + } + + for l in listeners { + if l.get_conn_count().await > 0 { + common.add_listener(l).await; + } + } + }); + + *locked_task = Some(task.into()); + return Ok(SendPunchPacketBothEasySymResponse { + is_busy: false, + base_mapped_addr: Some(cur_mapped_addr.into()), + }); + } +} + +#[derive(Debug)] +pub(crate) struct PunchBothEasySymHoleClient { + peer_mgr: Arc, +} + +impl PunchBothEasySymHoleClient { + pub(crate) fn new(peer_mgr: Arc) -> Self { + Self { peer_mgr } + } + + #[tracing::instrument(ret)] + pub(crate) async fn do_hole_punching( + &self, + dst_peer_id: PeerId, + my_nat_info: UdpNatType, + peer_nat_info: UdpNatType, + is_busy: &mut bool, + ) -> Result, anyhow::Error> { + *is_busy = false; + + let udp_array = UdpSocketArray::new( + UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM, + self.peer_mgr.get_global_ctx().net_ns.clone(), + ); + udp_array.start().await?; + + let global_ctx = self.peer_mgr.get_global_ctx(); + let cur_mapped_addr = global_ctx + .get_stun_info_collector() + .get_udp_port_mapping(0) + .await + .with_context(|| "failed to get udp port mapping")?; + let my_public_ip = match cur_mapped_addr.ip() { + IpAddr::V4(v4) => v4, + _ => { + anyhow::bail!("ipv6 is not supported"); + } + }; + let me_is_incremental = my_nat_info + .get_inc_of_easy_sym() + .ok_or(anyhow::anyhow!("me_is_incremental is required"))?; + let peer_is_incremental = peer_nat_info + .get_inc_of_easy_sym() + .ok_or(anyhow::anyhow!("peer_is_incremental is required"))?; + + let rpc_stub = self + .peer_mgr + .get_peer_rpc_mgr() + .rpc_client() + .scoped_client::>( + self.peer_mgr.my_peer_id(), + dst_peer_id, + global_ctx.get_network_name(), + ); + + let tid = rand::random(); + udp_array.add_intreast_tid(tid); + + let remote_ret = rpc_stub + .send_punch_packet_both_easy_sym( + BaseController { + timeout_ms: 2000, + ..Default::default() + }, + SendPunchPacketBothEasySymRequest { + transaction_id: tid, + public_ip: Some(my_public_ip.into()), + dst_port_num: if me_is_incremental { + cur_mapped_addr.port().saturating_add(DST_PORT_OFFSET) + } else { + cur_mapped_addr.port().saturating_sub(DST_PORT_OFFSET) + } as u32, + udp_socket_count: UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM as u32, + wait_time_ms: REMOTE_WAIT_TIME_MS as u32, + }, + ) + .await?; + if remote_ret.is_busy { + *is_busy = true; + anyhow::bail!("remote is busy"); + } + + let mut remote_mapped_addr = remote_ret + .base_mapped_addr + .ok_or(anyhow::anyhow!("remote_mapped_addr is required"))?; + + let now = Instant::now(); + remote_mapped_addr.port = if peer_is_incremental { + remote_mapped_addr + .port + .saturating_add(DST_PORT_OFFSET as u32) + } else { + remote_mapped_addr + .port + .saturating_sub(DST_PORT_OFFSET as u32) + }; + tracing::debug!( + ?remote_mapped_addr, + ?remote_ret, + "start send hole punch packet for both easy sym" + ); + + while now.elapsed().as_millis() < (REMOTE_WAIT_TIME_MS + 1000).into() { + udp_array + .send_with_all( + &new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(), + remote_mapped_addr.into(), + ) + .await?; + + tokio::time::sleep(Duration::from_millis(100)).await; + + let Some(socket) = udp_array.try_fetch_punched_socket(tid) else { + tracing::trace!( + ?remote_mapped_addr, + ?tid, + "no punched socket found, send some more hole punch packets" + ); + continue; + }; + + tracing::info!( + ?socket, + ?remote_mapped_addr, + ?tid, + "got punched socket in both easy sym" + ); + + for _ in 0..2 { + match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into()) + .await + { + Ok(tunnel) => { + return Ok(tunnel); + } + Err(e) => { + tracing::error!(?e, "failed to connect with socket"); + continue; + } + } + } + udp_array.add_new_socket(socket.socket).await?; + } + + anyhow::bail!("failed to punch hole for both easy sym"); + } +} + +#[cfg(test)] +pub mod tests { + use std::{ + sync::{atomic::AtomicU32, Arc}, + time::Duration, + }; + + use tokio::net::UdpSocket; + + use crate::{ + connector::udp_hole_punch::{ + tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, + }, + peers::tests::{connect_peer_manager, wait_route_appear}, + proto::common::NatType, + tunnel::common::tests::wait_for_condition, + }; + + #[rstest::rstest] + #[tokio::test] + #[serial_test::serial(hole_punch)] + async fn hole_punching_easy_sym(#[values("true", "false")] is_inc: bool) { + let p_a = create_mock_peer_manager_with_mock_stun(if is_inc { + NatType::SymmetricEasyInc + } else { + NatType::SymmetricEasyDec + }) + .await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(if !is_inc { + NatType::SymmetricEasyInc + } else { + NatType::SymmetricEasyDec + }) + .await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone()); + + hole_punching_a.run().await.unwrap(); + hole_punching_c.run().await.unwrap(); + + // 144 + DST_PORT_OFFSET = 164 + let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap()); + // 144 - DST_PORT_OFFSET = 124 + let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap()); + let udps = vec![udp1, udp2]; + + let counter = Arc::new(AtomicU32::new(0)); + + // all these sockets should receive hole punching packet + for udp in udps.iter().map(Arc::clone) { + let counter = counter.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + let (len, addr) = udp.recv_from(&mut buf).await.unwrap(); + println!( + "got predictable punch packet, {:?} {:?} {:?}", + len, + addr, + udp.local_addr() + ); + counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + }); + } + + hole_punching_a.client.run_immediately().await; + let udp_len = udps.len(); + wait_for_condition( + || async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 }, + Duration::from_secs(30), + ) + .await; + } +} diff --git a/easytier/src/connector/udp_hole_punch/common.rs b/easytier/src/connector/udp_hole_punch/common.rs new file mode 100644 index 0000000..cd225af --- /dev/null +++ b/easytier/src/connector/udp_hole_punch/common.rs @@ -0,0 +1,573 @@ +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, + time::Duration, +}; + +use crossbeam::atomic::AtomicCell; +use dashmap::{DashMap, DashSet}; +use rand::seq::SliceRandom as _; +use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; +use tracing::{instrument, Instrument, Level}; +use zerocopy::FromBytes as _; + +use crate::{ + common::{ + error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS, + stun::StunInfoCollectorTrait as _, + }, + defer, + peers::peer_manager::PeerManager, + proto::common::NatType, + tunnel::{ + packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE}, + udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, + Tunnel, TunnelConnCounter, TunnelListener as _, + }, +}; + +pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16; + +fn generate_shuffled_port_vec() -> Vec { + let mut rng = rand::thread_rng(); + let mut port_vec: Vec = (1..=65535).collect(); + port_vec.shuffle(&mut rng); + port_vec +} + +pub(crate) enum UdpPunchClientMethod { + None, + ConeToCone, + SymToCone, + EasySymToEasySym, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum UdpNatType { + Unknown, + Open(NatType), + Cone(NatType), + // bool means if it is incremental + EasySymmetric(NatType, bool), + HardSymmetric(NatType), +} + +impl From for UdpNatType { + fn from(nat_type: NatType) -> Self { + match nat_type { + NatType::Unknown => UdpNatType::Unknown, + NatType::NoPat | NatType::OpenInternet => UdpNatType::Open(nat_type), + NatType::FullCone | NatType::Restricted | NatType::PortRestricted => { + UdpNatType::Cone(nat_type) + } + NatType::Symmetric | NatType::SymUdpFirewall => UdpNatType::HardSymmetric(nat_type), + NatType::SymmetricEasyInc => UdpNatType::EasySymmetric(nat_type, true), + NatType::SymmetricEasyDec => UdpNatType::EasySymmetric(nat_type, false), + } + } +} + +impl Into for UdpNatType { + fn into(self) -> NatType { + match self { + UdpNatType::Unknown => NatType::Unknown, + UdpNatType::Open(nat_type) => nat_type, + UdpNatType::Cone(nat_type) => nat_type, + UdpNatType::EasySymmetric(nat_type, _) => nat_type, + UdpNatType::HardSymmetric(nat_type) => nat_type, + } + } +} + +impl UdpNatType { + pub(crate) fn is_open(&self) -> bool { + matches!(self, UdpNatType::Open(_)) + } + + pub(crate) fn is_unknown(&self) -> bool { + matches!(self, UdpNatType::Unknown) + } + + pub(crate) fn is_sym(&self) -> bool { + self.is_hard_sym() || self.is_easy_sym() + } + + pub(crate) fn is_hard_sym(&self) -> bool { + matches!(self, UdpNatType::HardSymmetric(_)) + } + + pub(crate) fn is_easy_sym(&self) -> bool { + matches!(self, UdpNatType::EasySymmetric(_, _)) + } + + pub(crate) fn is_cone(&self) -> bool { + matches!(self, UdpNatType::Cone(_)) + } + + pub(crate) fn get_inc_of_easy_sym(&self) -> Option { + match self { + UdpNatType::EasySymmetric(_, inc) => Some(*inc), + _ => None, + } + } + + pub(crate) fn get_punch_hole_method(&self, other: Self) -> UdpPunchClientMethod { + if other.is_unknown() { + if self.is_sym() { + return UdpPunchClientMethod::SymToCone; + } else { + return UdpPunchClientMethod::ConeToCone; + } + } + + if self.is_unknown() { + if other.is_sym() { + return UdpPunchClientMethod::None; + } else { + return UdpPunchClientMethod::ConeToCone; + } + } + + if self.is_open() || other.is_open() { + // open nat does not need to punch hole + return UdpPunchClientMethod::None; + } + + if self.is_cone() { + if other.is_sym() { + return UdpPunchClientMethod::None; + } else { + return UdpPunchClientMethod::ConeToCone; + } + } else if self.is_easy_sym() { + if other.is_hard_sym() { + return UdpPunchClientMethod::None; + } else if other.is_easy_sym() { + return UdpPunchClientMethod::EasySymToEasySym; + } else { + return UdpPunchClientMethod::SymToCone; + } + } else if self.is_hard_sym() { + if other.is_sym() { + return UdpPunchClientMethod::None; + } else { + return UdpPunchClientMethod::SymToCone; + } + } + + unreachable!("invalid nat type"); + } + + pub(crate) fn can_punch_hole_as_client(&self, other: Self) -> bool { + !matches!( + self.get_punch_hole_method(other), + UdpPunchClientMethod::None + ) + } +} + +#[derive(Debug)] +pub(crate) struct PunchedUdpSocket { + pub(crate) socket: Arc, + pub(crate) tid: u32, + pub(crate) remote_addr: SocketAddr, +} + +// used for symmetric hole punching, binding to multiple ports to increase the chance of success +pub(crate) struct UdpSocketArray { + sockets: Arc>>, + max_socket_count: usize, + net_ns: NetNS, + tasks: Arc>>, + + intreast_tids: Arc>, + tid_to_socket: Arc>>, +} + +impl UdpSocketArray { + pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self { + let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); + join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned()); + + Self { + sockets: Arc::new(DashMap::new()), + max_socket_count, + net_ns, + tasks, + + intreast_tids: Arc::new(DashSet::new()), + tid_to_socket: Arc::new(DashMap::new()), + } + } + + pub fn started(&self) -> bool { + !self.sockets.is_empty() + } + + pub async fn add_new_socket(&self, socket: Arc) -> Result<(), anyhow::Error> { + let socket_map = self.sockets.clone(); + let local_addr = socket.local_addr()?; + let intreast_tids = self.intreast_tids.clone(); + let tid_to_socket = self.tid_to_socket.clone(); + socket_map.insert(local_addr, socket.clone()); + self.tasks.lock().unwrap().spawn( + async move { + defer!(socket_map.remove(&local_addr);); + let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize]; + tracing::trace!(?local_addr, "udp socket added"); + loop { + let Ok((len, addr)) = socket.recv_from(&mut buf).await else { + break; + }; + + tracing::debug!(?len, ?addr, "got raw packet"); + + if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize { + continue; + } + + let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else { + continue; + }; + + let tid = p.conn_id.get(); + let valid = p.msg_type == UdpPacketType::HolePunch as u8 + && p.len.get() == HOLE_PUNCH_PACKET_BODY_LEN; + tracing::debug!(?p, ?addr, ?tid, ?valid, ?p, "got udp hole punch packet"); + + if !valid { + continue; + } + + if intreast_tids.contains(&tid) { + tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid"); + tid_to_socket + .entry(tid) + .or_insert_with(Vec::new) + .push(PunchedUdpSocket { + socket: socket.clone(), + tid, + remote_addr: addr, + }); + break; + } + } + tracing::debug!(?local_addr, "udp socket recv loop end"); + } + .instrument(tracing::info_span!("udp array socket recv loop")), + ); + Ok(()) + } + + #[instrument(err)] + pub async fn start(&self) -> Result<(), anyhow::Error> { + tracing::info!("starting udp socket array"); + + while self.sockets.len() < self.max_socket_count { + let socket = { + let _g = self.net_ns.guard(); + Arc::new(UdpSocket::bind("0.0.0.0:0").await?) + }; + + self.add_new_socket(socket).await?; + } + + Ok(()) + } + + #[instrument(err)] + pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> { + tracing::info!(?addr, "sending hole punching packet"); + + for socket in self.sockets.iter() { + let socket = socket.value(); + socket.send_to(data, addr).await?; + } + + Ok(()) + } + + #[instrument(ret(level = Level::DEBUG))] + pub fn try_fetch_punched_socket(&self, tid: u32) -> Option { + tracing::debug!(?tid, "try fetch punched socket"); + self.tid_to_socket.get_mut(&tid)?.value_mut().pop() + } + + pub fn add_intreast_tid(&self, tid: u32) { + self.intreast_tids.insert(tid); + } + + pub fn remove_intreast_tid(&self, tid: u32) { + self.intreast_tids.remove(&tid); + self.tid_to_socket.remove(&tid); + } +} + +impl std::fmt::Debug for UdpSocketArray { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UdpSocketArray") + .field("sockets", &self.sockets.len()) + .field("max_socket_count", &self.max_socket_count) + .field("started", &self.started()) + .field("intreast_tids", &self.intreast_tids.len()) + .field("tid_to_socket", &self.tid_to_socket.len()) + .finish() + } +} + +#[derive(Debug)] +pub(crate) struct UdpHolePunchListener { + socket: Arc, + tasks: JoinSet<()>, + running: Arc>, + mapped_addr: SocketAddr, + conn_counter: Arc>, + + listen_time: std::time::Instant, + last_select_time: AtomicCell, + last_active_time: Arc>, +} + +impl UdpHolePunchListener { + async fn get_avail_port() -> Result { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + Ok(socket.local_addr()?.port()) + } + + #[instrument(err)] + pub async fn new(peer_mgr: Arc) -> Result { + Self::new_ext(peer_mgr, true, None).await + } + + #[instrument(err)] + pub async fn new_ext( + peer_mgr: Arc, + with_mapped_addr: bool, + port: Option, + ) -> Result { + let port = port.unwrap_or(Self::get_avail_port().await?); + let listen_url = format!("udp://0.0.0.0:{}", port); + + let mapped_addr = if with_mapped_addr { + let gctx = peer_mgr.get_global_ctx(); + let stun_info_collect = gctx.get_stun_info_collector(); + stun_info_collect.get_udp_port_mapping(port).await? + } else { + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port)) + }; + + let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap()); + + { + let _g = peer_mgr.get_global_ctx().net_ns.guard(); + listener.listen().await?; + } + let socket = listener.get_socket().unwrap(); + + let running = Arc::new(AtomicCell::new(true)); + let running_clone = running.clone(); + + let conn_counter = listener.get_conn_counter(); + let mut tasks = JoinSet::new(); + + tasks.spawn(async move { + while let Ok(conn) = listener.accept().await { + tracing::warn!(?conn, "udp hole punching listener got peer connection"); + let peer_mgr = peer_mgr.clone(); + tokio::spawn(async move { + if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await { + tracing::error!( + ?e, + "failed to add tunnel as server in hole punch listener" + ); + } + }); + } + + running_clone.store(false); + }); + + let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now())); + let conn_counter_clone = conn_counter.clone(); + let last_active_time_clone = last_active_time.clone(); + tasks.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + if conn_counter_clone.get().unwrap_or(0) != 0 { + last_active_time_clone.store(std::time::Instant::now()); + } + } + }); + + tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started"); + + Ok(Self { + tasks, + socket, + running, + mapped_addr, + conn_counter, + + listen_time: std::time::Instant::now(), + last_select_time: AtomicCell::new(std::time::Instant::now()), + last_active_time, + }) + } + + pub async fn get_socket(&self) -> Arc { + self.last_select_time.store(std::time::Instant::now()); + self.socket.clone() + } + + pub async fn get_conn_count(&self) -> usize { + self.conn_counter.get().unwrap_or(0) as usize + } +} + +pub(crate) struct PunchHoleServerCommon { + peer_mgr: Arc, + + listeners: Arc>>, + tasks: Arc>>, +} + +impl PunchHoleServerCommon { + pub(crate) fn new(peer_mgr: Arc) -> Self { + let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); + join_joinset_background(tasks.clone(), "PunchHoleServerCommon".to_owned()); + + let listeners = Arc::new(Mutex::new(Vec::::new())); + + let l = listeners.clone(); + tasks.lock().unwrap().spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(5)).await; + { + // remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds + l.lock().await.retain(|listener| { + listener.last_active_time.load().elapsed().as_secs() < 40 + || listener.last_select_time.load().elapsed().as_secs() < 30 + }); + } + } + }); + + Self { + peer_mgr, + + listeners, + tasks, + } + } + + pub(crate) async fn add_listener(&self, listener: UdpHolePunchListener) { + self.listeners.lock().await.push(listener); + } + + pub(crate) async fn find_listener(&self, addr: &SocketAddr) -> Option> { + let all_listener_sockets = self.listeners.lock().await; + + let listener = all_listener_sockets + .iter() + .find(|listener| listener.mapped_addr == *addr && listener.running.load())?; + + Some(listener.get_socket().await) + } + + pub(crate) async fn my_udp_nat_type(&self) -> i32 { + self.peer_mgr + .get_global_ctx() + .get_stun_info_collector() + .get_stun_info() + .udp_nat_type + } + + pub(crate) async fn select_listener( + &self, + use_new_listener: bool, + ) -> Option<(Arc, SocketAddr)> { + let all_listener_sockets = &self.listeners; + + let mut use_last = false; + if all_listener_sockets.lock().await.len() < 16 || use_new_listener { + tracing::warn!("creating new udp hole punching listener"); + all_listener_sockets.lock().await.push( + UdpHolePunchListener::new(self.peer_mgr.clone()) + .await + .ok()?, + ); + use_last = true; + } + + let locked = all_listener_sockets.lock().await; + + let listener = if use_last { + locked.last()? + } else { + // use the listener that is active most recently + locked + .iter() + .max_by_key(|listener| listener.last_active_time.load())? + }; + + Some((listener.get_socket().await, listener.mapped_addr)) + } + + pub(crate) fn get_joinset(&self) -> Arc>> { + self.tasks.clone() + } + + pub(crate) fn get_global_ctx(&self) -> ArcGlobalCtx { + self.peer_mgr.get_global_ctx() + } + + pub(crate) fn get_peer_mgr(&self) -> Arc { + self.peer_mgr.clone() + } +} + +#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))] +pub(crate) async fn send_symmetric_hole_punch_packet( + ports: &Vec, + udp: Arc, + transaction_id: u32, + public_ips: &Vec, + port_start_idx: usize, + max_packets: usize, +) -> Result { + tracing::debug!("sending hard symmetric hole punching packet"); + let mut sent_packets = 0; + let mut cur_port_idx = port_start_idx; + while sent_packets < max_packets { + let port = ports[cur_port_idx % ports.len()]; + for pub_ip in public_ips { + let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, port)); + let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN); + udp.send_to(&packet.into_bytes(), addr).await?; + sent_packets += 1; + } + cur_port_idx = cur_port_idx.wrapping_add(1); + tokio::time::sleep(Duration::from_millis(3)).await; + } + Ok(cur_port_idx % ports.len()) +} + +pub(crate) async fn try_connect_with_socket( + socket: Arc, + remote_mapped_addr: SocketAddr, +) -> Result, Error> { + let connector = UdpTunnelConnector::new( + format!( + "udp://{}:{}", + remote_mapped_addr.ip(), + remote_mapped_addr.port() + ) + .to_string() + .parse() + .unwrap(), + ); + connector + .try_connect_with_socket(socket, remote_mapped_addr) + .await + .map_err(|e| Error::from(e)) +} diff --git a/easytier/src/connector/udp_hole_punch/cone.rs b/easytier/src/connector/udp_hole_punch/cone.rs new file mode 100644 index 0000000..e8acf20 --- /dev/null +++ b/easytier/src/connector/udp_hole_punch/cone.rs @@ -0,0 +1,258 @@ +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Context; +use tokio::net::UdpSocket; + +use crate::{ + common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId}, + connector::udp_hole_punch::common::{ + try_connect_with_socket, UdpSocketArray, HOLE_PUNCH_PACKET_BODY_LEN, + }, + peers::peer_manager::PeerManager, + proto::{ + common::Void, + peer_rpc::{ + SelectPunchListenerRequest, SendPunchPacketConeRequest, UdpHolePunchRpcClientFactory, + }, + rpc_types::{self, controller::BaseController}, + }, + tunnel::{udp::new_hole_punch_packet, Tunnel}, +}; + +use super::common::PunchHoleServerCommon; + +pub(crate) struct PunchConeHoleServer { + common: Arc, +} + +impl PunchConeHoleServer { + pub(crate) fn new(common: Arc) -> Self { + Self { common } + } + + #[tracing::instrument(skip(self), ret, err)] + pub(crate) async fn send_punch_packet_cone( + &self, + _: BaseController, + request: SendPunchPacketConeRequest, + ) -> Result { + let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!( + "send_punch_packet_for_cone request missing listener_mapped_addr" + ))?; + let listener_addr = std::net::SocketAddr::from(listener_addr); + let listener = self + .common + .find_listener(&listener_addr) + .await + .ok_or(anyhow::anyhow!( + "send_punch_packet_for_cone failed to find listener" + ))?; + + let dest_addr = request.dest_addr.ok_or(anyhow::anyhow!( + "send_punch_packet_for_cone request missing dest_addr" + ))?; + let dest_addr = std::net::SocketAddr::from(dest_addr); + let dest_ip = dest_addr.ip(); + if dest_ip.is_unspecified() || dest_ip.is_multicast() { + return Err(anyhow::anyhow!( + "send_punch_packet_for_cone dest_ip is malformed, {:?}", + request + ) + .into()); + } + + for _ in 0..request.packet_batch_count { + tracing::info!(?request, "sending hole punching packet"); + + for _ in 0..request.packet_count_per_batch { + let udp_packet = + new_hole_punch_packet(request.transaction_id, HOLE_PUNCH_PACKET_BODY_LEN); + if let Err(e) = listener.send_to(&udp_packet.into_bytes(), &dest_addr).await { + tracing::error!(?e, "failed to send hole punch packet to dest addr"); + } + } + tokio::time::sleep(Duration::from_millis(request.packet_interval_ms as u64)).await; + } + + Ok(Void::default()) + } +} + +pub(crate) struct PunchConeHoleClient { + peer_mgr: Arc, +} + +impl PunchConeHoleClient { + pub(crate) fn new(peer_mgr: Arc) -> Self { + Self { peer_mgr } + } + + #[tracing::instrument(skip(self))] + pub(crate) async fn do_hole_punching( + &self, + dst_peer_id: PeerId, + ) -> Result, anyhow::Error> { + tracing::info!(?dst_peer_id, "start hole punching"); + let tid = rand::random(); + + let global_ctx = self.peer_mgr.get_global_ctx(); + let udp_array = UdpSocketArray::new(1, global_ctx.net_ns.clone()); + let local_socket = { + let _g = self.peer_mgr.get_global_ctx().net_ns.guard(); + Arc::new(UdpSocket::bind("0.0.0.0:0").await?) + }; + + let local_addr = local_socket + .local_addr() + .with_context(|| anyhow::anyhow!("failed to get local port from udp array"))?; + let local_port = local_addr.port(); + + let local_mapped_addr = global_ctx + .get_stun_info_collector() + .get_udp_port_mapping(local_port) + .await + .with_context(|| "failed to get udp port mapping")?; + + // client -> server: tell server the mapped port, server will return the mapped address of listening port. + let rpc_stub = self + .peer_mgr + .get_peer_rpc_mgr() + .rpc_client() + .scoped_client::>( + self.peer_mgr.my_peer_id(), + dst_peer_id, + global_ctx.get_network_name(), + ); + + let resp = rpc_stub + .select_punch_listener( + BaseController::default(), + SelectPunchListenerRequest { force_new: false }, + ) + .await + .with_context(|| "failed to select punch listener")?; + let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!( + "select_punch_listener response missing listener_mapped_addr" + ))?; + + tracing::debug!( + ?local_mapped_addr, + ?remote_mapped_addr, + "hole punch got remote listener" + ); + + udp_array.add_new_socket(local_socket).await?; + udp_array.add_intreast_tid(tid); + let send_from_local = || async { + udp_array + .send_with_all( + &new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(), + remote_mapped_addr.clone().into(), + ) + .await + .with_context(|| "failed to send hole punch packet from local") + }; + + send_from_local().await?; + + let scoped_punch_task: ScopedTask<()> = tokio::spawn(async move { + if let Err(e) = rpc_stub + .send_punch_packet_cone( + BaseController { + timeout_ms: 4000, + ..Default::default() + }, + SendPunchPacketConeRequest { + listener_mapped_addr: Some(remote_mapped_addr.into()), + dest_addr: Some(local_mapped_addr.into()), + transaction_id: tid, + packet_count_per_batch: 2, + packet_batch_count: 5, + packet_interval_ms: 400, + }, + ) + .await + { + tracing::error!(?e, "failed to call remote send punch packet"); + } + }) + .into(); + + // server: will send some punching resps, total 10 packets. + // client: use the socket to create UdpTunnel with UdpTunnelConnector + // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote. + let mut finish_time: Option = None; + while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 { + tokio::time::sleep(Duration::from_millis(200)).await; + + if finish_time.is_none() && (*scoped_punch_task).is_finished() { + finish_time = Some(Instant::now()); + } + + let Some(socket) = udp_array.try_fetch_punched_socket(tid) else { + tracing::debug!("no punched socket found, send some more hole punch packets"); + send_from_local().await?; + continue; + }; + + tracing::debug!(?socket, ?tid, "punched socket found, try connect with it"); + + for _ in 0..2 { + match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into()) + .await + { + Ok(tunnel) => { + tracing::info!(?tunnel, "hole punched"); + return Ok(tunnel); + } + Err(e) => { + tracing::error!(?e, "failed to connect with socket"); + } + } + } + } + + return Err(anyhow::anyhow!("punch task finished but no hole punched")); + } +} + +#[cfg(test)] +pub mod tests { + + use crate::{ + connector::udp_hole_punch::{ + tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, + }, + peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost}, + proto::common::NatType, + }; + + #[tokio::test] + async fn hole_punching_cone() { + let p_a = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + println!("{:?}", p_a.list_routes().await); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone()); + + hole_punching_a.run_as_client().await.unwrap(); + hole_punching_c.run_as_server().await.unwrap(); + + hole_punching_a.client.run_immediately().await; + + wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) + .await + .unwrap(); + println!("{:?}", p_a.list_routes().await); + } +} diff --git a/easytier/src/connector/udp_hole_punch/mod.rs b/easytier/src/connector/udp_hole_punch/mod.rs new file mode 100644 index 0000000..0124149 --- /dev/null +++ b/easytier/src/connector/udp_hole_punch/mod.rs @@ -0,0 +1,482 @@ +use std::sync::Arc; + +use anyhow::Error; +use both_easy_sym::{PunchBothEasySymHoleClient, PunchBothEasySymHoleServer}; +use common::{PunchHoleServerCommon, UdpNatType, UdpPunchClientMethod}; +use cone::{PunchConeHoleClient, PunchConeHoleServer}; +use sym_to_cone::{PunchSymToConeHoleClient, PunchSymToConeHoleServer}; +use tokio::{sync::Mutex, task::JoinHandle}; + +use crate::{ + common::{stun::StunInfoCollectorTrait, PeerId}, + connector::direct::PeerManagerForDirectConnector, + peers::{ + peer_manager::PeerManager, + peer_task::{PeerTaskLauncher, PeerTaskManager}, + }, + proto::{ + common::{NatType, Void}, + peer_rpc::{ + SelectPunchListenerRequest, SelectPunchListenerResponse, + SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse, + SendPunchPacketConeRequest, SendPunchPacketEasySymRequest, + SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse, UdpHolePunchRpc, + UdpHolePunchRpcServer, + }, + rpc_types::{self, controller::BaseController}, + }, +}; + +pub(crate) mod both_easy_sym; +pub(crate) mod common; +pub(crate) mod cone; +pub(crate) mod sym_to_cone; + +struct UdpHolePunchServer { + common: Arc, + cone_server: PunchConeHoleServer, + sym_to_cone_server: PunchSymToConeHoleServer, + both_easy_sym_server: PunchBothEasySymHoleServer, +} + +impl UdpHolePunchServer { + pub fn new(peer_mgr: Arc) -> Arc { + let common = Arc::new(PunchHoleServerCommon::new(peer_mgr.clone())); + let cone_server = PunchConeHoleServer::new(common.clone()); + let sym_to_cone_server = PunchSymToConeHoleServer::new(common.clone()); + let both_easy_sym_server = PunchBothEasySymHoleServer::new(common.clone()); + + Arc::new(Self { + common, + cone_server, + sym_to_cone_server, + both_easy_sym_server, + }) + } +} + +#[async_trait::async_trait] +impl UdpHolePunchRpc for UdpHolePunchServer { + type Controller = BaseController; + + async fn select_punch_listener( + &self, + _ctrl: Self::Controller, + input: SelectPunchListenerRequest, + ) -> rpc_types::error::Result { + let (_, addr) = self + .common + .select_listener(input.force_new) + .await + .ok_or(anyhow::anyhow!("no listener available"))?; + + Ok(SelectPunchListenerResponse { + listener_mapped_addr: Some(addr.into()), + }) + } + + /// send packet to one remote_addr, used by nat1-3 to nat1-3 + async fn send_punch_packet_cone( + &self, + ctrl: Self::Controller, + input: SendPunchPacketConeRequest, + ) -> rpc_types::error::Result { + self.cone_server.send_punch_packet_cone(ctrl, input).await + } + + /// send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3 + async fn send_punch_packet_hard_sym( + &self, + _ctrl: Self::Controller, + input: SendPunchPacketHardSymRequest, + ) -> rpc_types::error::Result { + self.sym_to_cone_server + .send_punch_packet_hard_sym(input) + .await + } + + async fn send_punch_packet_easy_sym( + &self, + _ctrl: Self::Controller, + input: SendPunchPacketEasySymRequest, + ) -> rpc_types::error::Result { + self.sym_to_cone_server + .send_punch_packet_easy_sym(input) + .await + .map(|_| Void {}) + } + + /// nat4 to nat4 (both predictably) + async fn send_punch_packet_both_easy_sym( + &self, + _ctrl: Self::Controller, + input: SendPunchPacketBothEasySymRequest, + ) -> rpc_types::error::Result { + self.both_easy_sym_server + .send_punch_packet_both_easy_sym(input) + .await + } +} + +struct BackOff { + backoffs_ms: Vec, + current_idx: usize, +} + +impl BackOff { + pub fn new(backoffs_ms: Vec) -> Self { + Self { + backoffs_ms, + current_idx: 0, + } + } + + pub fn next_backoff(&mut self) -> u64 { + let backoff = self.backoffs_ms[self.current_idx]; + self.current_idx = (self.current_idx + 1).min(self.backoffs_ms.len() - 1); + backoff + } + + pub fn rollback(&mut self) { + self.current_idx = self.current_idx.saturating_sub(1); + } + + pub async fn sleep_for_next_backoff(&mut self) { + let backoff = self.next_backoff(); + if backoff > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(backoff)).await; + } + } +} + +struct UdpHoePunchConnectorData { + cone_client: PunchConeHoleClient, + sym_to_cone_client: PunchSymToConeHoleClient, + both_easy_sym_client: PunchBothEasySymHoleClient, + peer_mgr: Arc, + + // sym punch should be serialized + sym_punch_lock: Mutex<()>, +} + +impl UdpHoePunchConnectorData { + pub fn new(peer_mgr: Arc) -> Arc { + let cone_client = PunchConeHoleClient::new(peer_mgr.clone()); + let sym_to_cone_client = PunchSymToConeHoleClient::new(peer_mgr.clone()); + let both_easy_sym_client = PunchBothEasySymHoleClient::new(peer_mgr.clone()); + + Arc::new(Self { + cone_client, + sym_to_cone_client, + both_easy_sym_client, + peer_mgr, + sym_punch_lock: Mutex::new(()), + }) + } + + #[tracing::instrument(skip(self))] + async fn cone_to_cone(self: Arc, task_info: PunchTaskInfo) -> Result<(), Error> { + let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000]); + + loop { + backoff.sleep_for_next_backoff().await; + + let ret = self + .cone_client + .do_hole_punching(task_info.dst_peer_id) + .await; + if let Err(e) = ret { + tracing::info!(?e, "cone_to_cone hole punching failed"); + continue; + } + + if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await { + tracing::warn!(?e, "cone_to_cone add client tunnel failed"); + continue; + } + + break; + } + + tracing::info!("cone_to_cone hole punching success"); + Ok(()) + } + + #[tracing::instrument(skip(self))] + async fn sym_to_cone(self: Arc, task_info: PunchTaskInfo) -> Result<(), Error> { + let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]); + let mut round = 0; + let mut port_idx = rand::random(); + + loop { + backoff.sleep_for_next_backoff().await; + + let ret = { + let _lock = self.sym_punch_lock.lock().await; + self.sym_to_cone_client + .do_hole_punching( + task_info.dst_peer_id, + round, + &mut port_idx, + task_info.my_nat_type, + ) + .await + }; + + round += 1; + + if let Err(e) = ret { + tracing::info!(?e, "sym_to_cone hole punching failed"); + continue; + } + + if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await { + tracing::warn!(?e, "sym_to_cone add client tunnel failed"); + continue; + } + + break; + } + + Ok(()) + } + + #[tracing::instrument(skip(self))] + async fn both_easy_sym(self: Arc, task_info: PunchTaskInfo) -> Result<(), Error> { + let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]); + + loop { + backoff.sleep_for_next_backoff().await; + + let mut is_busy = false; + + let ret = { + let _lock = self.sym_punch_lock.lock().await; + self.both_easy_sym_client + .do_hole_punching( + task_info.dst_peer_id, + task_info.my_nat_type, + task_info.dst_nat_type, + &mut is_busy, + ) + .await + }; + + if is_busy { + backoff.rollback(); + } + + if let Err(e) = ret { + tracing::info!(?e, "both_easy_sym hole punching failed"); + continue; + } + + if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await { + tracing::warn!(?e, "both_easy_sym add client tunnel failed"); + continue; + } + + break; + } + + Ok(()) + } +} + +#[derive(Clone)] +struct UdpHolePunchPeerTaskLauncher {} + +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +struct PunchTaskInfo { + dst_peer_id: PeerId, + dst_nat_type: UdpNatType, + my_nat_type: UdpNatType, +} + +#[async_trait::async_trait] +impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher { + type Data = Arc; + type CollectPeerItem = PunchTaskInfo; + type TaskRet = (); + + fn new_data(&self, peer_mgr: Arc) -> Self::Data { + UdpHoePunchConnectorData::new(peer_mgr) + } + + async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec { + let my_nat_type = data + .peer_mgr + .get_global_ctx() + .get_stun_info_collector() + .get_stun_info() + .udp_nat_type; + let my_nat_type: UdpNatType = NatType::try_from(my_nat_type) + .unwrap_or(NatType::Unknown) + .into(); + if !my_nat_type.is_sym() { + data.sym_to_cone_client.clear_udp_array().await; + } + + let mut peers_to_connect: Vec = Vec::new(); + // do not do anything if: + // 1. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us + // notice that if we are unknown, we treat ourselves as cone + if my_nat_type.is_open() { + return peers_to_connect; + } + + // collect peer list from peer manager and do some filter: + // 1. peers without direct conns; + // 2. peers is full cone (any restricted type); + for route in data.peer_mgr.list_routes().await.iter() { + if route + .feature_flag + .map(|x| x.is_public_server) + .unwrap_or(false) + { + continue; + } + + let peer_nat_type = route + .stun_info + .as_ref() + .map(|x| x.udp_nat_type) + .unwrap_or(0); + let Ok(peer_nat_type) = NatType::try_from(peer_nat_type) else { + continue; + }; + let peer_nat_type = peer_nat_type.into(); + + let peer_id: PeerId = route.peer_id; + let conns = data.peer_mgr.list_peer_conns(peer_id).await; + if conns.is_some() && conns.unwrap().len() > 0 { + continue; + } + + if !my_nat_type.can_punch_hole_as_client(peer_nat_type) { + continue; + } + + tracing::info!( + ?peer_id, + ?peer_nat_type, + ?my_nat_type, + "found peer to do hole punching" + ); + + peers_to_connect.push(PunchTaskInfo { + dst_peer_id: peer_id, + dst_nat_type: peer_nat_type, + my_nat_type, + }); + } + + peers_to_connect + } + + async fn launch_task( + &self, + data: &Self::Data, + item: Self::CollectPeerItem, + ) -> JoinHandle> { + let data = data.clone(); + let punch_method = item.my_nat_type.get_punch_hole_method(item.dst_nat_type); + match punch_method { + UdpPunchClientMethod::ConeToCone => tokio::spawn(data.cone_to_cone(item)), + UdpPunchClientMethod::SymToCone => tokio::spawn(data.sym_to_cone(item)), + UdpPunchClientMethod::EasySymToEasySym => tokio::spawn(data.both_easy_sym(item)), + _ => unreachable!(), + } + } + + async fn all_task_done(&self, data: &Self::Data) { + data.sym_to_cone_client.clear_udp_array().await; + } + + fn loop_interval_ms(&self) -> u64 { + 5000 + } +} + +pub struct UdpHolePunchConnector { + server: Arc, + client: PeerTaskManager, + peer_mgr: Arc, +} + +// Currently support: +// Symmetric -> Full Cone +// Any Type of Full Cone -> Any Type of Full Cone + +// if same level of full cone, node with smaller peer_id will be the initiator +// if different level of full cone, node with more strict level will be the initiator + +impl UdpHolePunchConnector { + pub fn new(peer_mgr: Arc) -> Self { + Self { + server: UdpHolePunchServer::new(peer_mgr.clone()), + client: PeerTaskManager::new(UdpHolePunchPeerTaskLauncher {}, peer_mgr.clone()), + peer_mgr, + } + } + + pub async fn run_as_client(&mut self) -> Result<(), Error> { + self.client.start(); + Ok(()) + } + + pub async fn run_as_server(&mut self) -> Result<(), Error> { + self.peer_mgr + .get_peer_rpc_mgr() + .rpc_server() + .registry() + .register( + UdpHolePunchRpcServer::new(self.server.clone()), + &self.peer_mgr.get_global_ctx().get_network_name(), + ); + + Ok(()) + } + + pub async fn run(&mut self) -> Result<(), Error> { + let global_ctx = self.peer_mgr.get_global_ctx(); + + if global_ctx.get_flags().disable_p2p { + return Ok(()); + } + if global_ctx.get_flags().disable_udp_hole_punching { + return Ok(()); + } + + self.run_as_client().await?; + self.run_as_server().await?; + + Ok(()) + } +} + +#[cfg(test)] +pub mod tests { + + use std::sync::Arc; + + use crate::common::stun::MockStunInfoCollector; + use crate::proto::common::NatType; + + use crate::peers::{peer_manager::PeerManager, tests::create_mock_peer_manager}; + + pub fn replace_stun_info_collector(peer_mgr: Arc, udp_nat_type: NatType) { + let collector = Box::new(MockStunInfoCollector { udp_nat_type }); + peer_mgr + .get_global_ctx() + .replace_stun_info_collector(collector); + } + + pub async fn create_mock_peer_manager_with_mock_stun( + udp_nat_type: NatType, + ) -> Arc { + let p_a = create_mock_peer_manager().await; + replace_stun_info_collector(p_a.clone(), udp_nat_type); + p_a + } +} diff --git a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs new file mode 100644 index 0000000..eae949d --- /dev/null +++ b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs @@ -0,0 +1,591 @@ +use std::{ + net::Ipv4Addr, + ops::{Div, Mul}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; + +use anyhow::Context; +use rand::{seq::SliceRandom, Rng}; +use tokio::{net::UdpSocket, sync::RwLock}; +use tracing::Level; + +use crate::{ + common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId}, + connector::udp_hole_punch::common::{ + send_symmetric_hole_punch_packet, try_connect_with_socket, HOLE_PUNCH_PACKET_BODY_LEN, + }, + defer, + peers::peer_manager::PeerManager, + proto::{ + peer_rpc::{ + SelectPunchListenerRequest, SendPunchPacketEasySymRequest, + SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse, + UdpHolePunchRpcClientFactory, + }, + rpc_types::{self, controller::BaseController}, + }, + tunnel::{udp::new_hole_punch_packet, Tunnel}, +}; + +use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray}; + +const UDP_ARRAY_SIZE_FOR_HARD_SYM: usize = 84; + +pub(crate) struct PunchSymToConeHoleServer { + common: Arc, + + shuffled_port_vec: Arc>, +} + +impl PunchSymToConeHoleServer { + pub(crate) fn new(common: Arc) -> Self { + let mut shuffled_port_vec: Vec = (1..=65535).collect(); + shuffled_port_vec.shuffle(&mut rand::thread_rng()); + + Self { + common, + shuffled_port_vec: Arc::new(shuffled_port_vec), + } + } + + // hard sym means public port is random and cannot be predicted + #[tracing::instrument(skip(self), ret)] + pub(crate) async fn send_punch_packet_easy_sym( + &self, + request: SendPunchPacketEasySymRequest, + ) -> Result<(), rpc_types::error::Error> { + tracing::info!("send_punch_packet_easy_sym start"); + + let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!( + "send_punch_packet_easy_sym request missing listener_addr" + ))?; + let listener_addr = std::net::SocketAddr::from(listener_addr); + let listener = self + .common + .find_listener(&listener_addr) + .await + .ok_or(anyhow::anyhow!( + "send_punch_packet_easy_sym failed to find listener" + ))?; + + let public_ips = request + .public_ips + .into_iter() + .map(|ip| std::net::Ipv4Addr::from(ip)) + .collect::>(); + if public_ips.len() == 0 { + tracing::warn!("send_punch_packet_easy_sym got zero len public ip"); + return Err( + anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(), + ); + } + + let transaction_id = request.transaction_id; + let base_port_num = request.base_port_num; + let max_port_num = request.max_port_num.max(1); + let is_incremental = request.is_incremental; + + let port_start = if is_incremental { + base_port_num.saturating_add(1) + } else { + base_port_num.saturating_sub(max_port_num) + }; + + let port_end = if is_incremental { + base_port_num.saturating_add(max_port_num) + } else { + base_port_num.saturating_sub(1) + }; + + if port_end <= port_start { + return Err(anyhow::anyhow!("send_punch_packet_easy_sym invalid port range").into()); + } + + let ports = (port_start..=port_end) + .map(|x| x as u16) + .collect::>(); + tracing::debug!( + ?ports, + ?public_ips, + "send_punch_packet_easy_sym send to ports" + ); + send_symmetric_hole_punch_packet( + &ports, + listener, + transaction_id, + &public_ips, + 0, + ports.len(), + ) + .await + .with_context(|| "failed to send symmetric hole punch packet")?; + + Ok(()) + } + + // hard sym means public port is random and cannot be predicted + #[tracing::instrument(skip(self))] + pub(crate) async fn send_punch_packet_hard_sym( + &self, + request: SendPunchPacketHardSymRequest, + ) -> Result { + tracing::info!("try_punch_symmetric start"); + + let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!( + "try_punch_symmetric request missing listener_addr" + ))?; + let listener_addr = std::net::SocketAddr::from(listener_addr); + let listener = self + .common + .find_listener(&listener_addr) + .await + .ok_or(anyhow::anyhow!( + "send_punch_packet_for_cone failed to find listener" + ))?; + + let public_ips = request + .public_ips + .into_iter() + .map(|ip| std::net::Ipv4Addr::from(ip)) + .collect::>(); + if public_ips.len() == 0 { + tracing::warn!("try_punch_symmetric got zero len public ip"); + return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into()); + } + + let transaction_id = request.transaction_id; + let last_port_index = request.port_index as usize; + + let round = std::cmp::max(request.round, 1); + + // send max k1 packets if we are predicting the dst port + let max_k1: u32 = 180; + // send max k2 packets if we are sending to random port + let mut max_k2: u32 = rand::thread_rng().gen_range(600..800); + if round > 2 { + max_k2 = max_k2.mul(2).div(round).max(max_k1); + } + + let next_port_index = send_symmetric_hole_punch_packet( + &self.shuffled_port_vec, + listener.clone(), + transaction_id, + &public_ips, + last_port_index, + max_k2 as usize, + ) + .await + .with_context(|| "failed to send symmetric hole punch packet randomly")?; + + return Ok(SendPunchPacketHardSymResponse { + next_port_index: next_port_index as u32, + }); + } +} + +pub(crate) struct PunchSymToConeHoleClient { + peer_mgr: Arc, + udp_array: RwLock>>, + try_direct_connect: AtomicBool, + punch_predicablely: AtomicBool, + punch_randomly: AtomicBool, +} + +impl PunchSymToConeHoleClient { + pub(crate) fn new(peer_mgr: Arc) -> Self { + Self { + peer_mgr, + udp_array: RwLock::new(None), + try_direct_connect: AtomicBool::new(true), + punch_predicablely: AtomicBool::new(true), + punch_randomly: AtomicBool::new(true), + } + } + + async fn prepare_udp_array(&self) -> Result, anyhow::Error> { + let rlocked = self.udp_array.read().await; + if let Some(udp_array) = rlocked.clone() { + return Ok(udp_array); + } + + drop(rlocked); + let mut wlocked = self.udp_array.write().await; + if let Some(udp_array) = wlocked.clone() { + return Ok(udp_array); + } + + let udp_array = Arc::new(UdpSocketArray::new( + UDP_ARRAY_SIZE_FOR_HARD_SYM, + self.peer_mgr.get_global_ctx().net_ns.clone(), + )); + udp_array.start().await?; + wlocked.replace(udp_array.clone()); + Ok(udp_array) + } + + pub(crate) async fn clear_udp_array(&self) { + let mut wlocked = self.udp_array.write().await; + wlocked.take(); + } + + async fn get_base_port_for_easy_sym(&self, my_nat_info: UdpNatType) -> Option { + let global_ctx = self.peer_mgr.get_global_ctx(); + if my_nat_info.is_easy_sym() { + match global_ctx + .get_stun_info_collector() + .get_udp_port_mapping(0) + .await + { + Ok(addr) => Some(addr.port()), + ret => { + tracing::warn!(?ret, "failed to get udp port mapping for easy sym"); + None + } + } + } else { + None + } + } + + #[tracing::instrument(err(level = Level::ERROR), skip(self))] + pub(crate) async fn do_hole_punching( + &self, + dst_peer_id: PeerId, + round: u32, + last_port_idx: &mut usize, + my_nat_info: UdpNatType, + ) -> Result, anyhow::Error> { + let udp_array = self.prepare_udp_array().await?; + let global_ctx = self.peer_mgr.get_global_ctx(); + + let rpc_stub = self + .peer_mgr + .get_peer_rpc_mgr() + .rpc_client() + .scoped_client::>( + self.peer_mgr.my_peer_id(), + dst_peer_id, + global_ctx.get_network_name(), + ); + + let resp = rpc_stub + .select_punch_listener( + BaseController::default(), + SelectPunchListenerRequest { force_new: false }, + ) + .await + .with_context(|| "failed to select punch listener")?; + let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!( + "select_punch_listener response missing listener_mapped_addr" + ))?; + + // try direct connect first + if self.try_direct_connect.load(Ordering::Relaxed) { + if let Ok(tunnel) = try_connect_with_socket( + Arc::new(UdpSocket::bind("0.0.0.0:0").await?), + remote_mapped_addr.into(), + ) + .await + { + return Ok(tunnel); + } + } + + let stun_info = global_ctx.get_stun_info_collector().get_stun_info(); + let public_ips: Vec = stun_info + .public_ip + .iter() + .map(|x| x.parse().unwrap()) + .collect(); + if public_ips.is_empty() { + return Err(anyhow::anyhow!("failed to get public ips")); + } + + let tid = rand::thread_rng().gen(); + let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(); + udp_array.add_intreast_tid(tid); + defer! { udp_array.remove_intreast_tid(tid);} + udp_array + .send_with_all(&packet, remote_mapped_addr.into()) + .await?; + + let port_index = *last_port_idx as u32; + let base_port_for_easy_sym = self.get_base_port_for_easy_sym(my_nat_info).await; + let punch_random = self.punch_randomly.load(Ordering::Relaxed); + let punch_predicable = self.punch_predicablely.load(Ordering::Relaxed); + let scoped_punch_task: ScopedTask> = tokio::spawn(async move { + if punch_predicable { + if let Some(inc) = my_nat_info.get_inc_of_easy_sym() { + let req = SendPunchPacketEasySymRequest { + listener_mapped_addr: remote_mapped_addr.clone().into(), + public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(), + transaction_id: tid, + base_port_num: base_port_for_easy_sym.unwrap() as u32, + max_port_num: 50, + is_incremental: inc, + }; + tracing::debug!(?req, "send punch packet for easy sym start"); + let ret = rpc_stub + .send_punch_packet_easy_sym( + BaseController { + timeout_ms: 4000, + trace_id: 0, + }, + req, + ) + .await; + tracing::debug!(?ret, "send punch packet for easy sym return"); + } + } + + if punch_random { + let req = SendPunchPacketHardSymRequest { + listener_mapped_addr: remote_mapped_addr.clone().into(), + public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(), + transaction_id: tid, + round, + port_index, + }; + tracing::debug!(?req, "send punch packet for hard sym start"); + match rpc_stub + .send_punch_packet_hard_sym( + BaseController { + timeout_ms: 4000, + trace_id: 0, + }, + req, + ) + .await + { + Err(e) => { + tracing::error!(?e, "failed to send punch packet for hard sym"); + return None; + } + Ok(resp) => return Some(resp.next_port_index), + } + } + + None + }) + .into(); + + // no matter what the result is, we should check if we received any hole punching packet + let mut ret_tunnel: Option> = None; + let mut finish_time: Option = None; + while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 { + tokio::time::sleep(Duration::from_millis(200)).await; + + if finish_time.is_none() && (*scoped_punch_task).is_finished() { + finish_time = Some(Instant::now()); + } + + let Some(socket) = udp_array.try_fetch_punched_socket(tid) else { + tracing::debug!("no punched socket found, wait for more time"); + continue; + }; + + // if hole punched but tunnel creation failed, need to retry entire process. + match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into()).await { + Ok(tunnel) => { + ret_tunnel.replace(tunnel); + break; + } + Err(e) => { + tracing::error!(?e, "failed to connect with socket"); + udp_array.add_new_socket(socket.socket).await?; + continue; + } + } + } + + let punch_task_result = scoped_punch_task.await; + tracing::debug!(?punch_task_result, ?ret_tunnel, "punch task got result"); + + if let Ok(Some(next_port_idx)) = punch_task_result { + *last_port_idx = next_port_idx as usize; + } else { + *last_port_idx = rand::random(); + } + + if let Some(tunnel) = ret_tunnel { + Ok(tunnel) + } else { + anyhow::bail!( + "failed to hole punch, punch task result: {:?}", + punch_task_result + ) + } + } +} + +#[cfg(test)] +pub mod tests { + use std::{ + sync::{atomic::AtomicU32, Arc}, + time::Duration, + }; + + use tokio::net::UdpSocket; + + use crate::{ + connector::udp_hole_punch::{ + tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector, + }, + peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost}, + proto::common::NatType, + tunnel::common::tests::wait_for_condition, + }; + + #[tokio::test] + async fn hole_punching_symmetric_only_random() { + let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone()); + + hole_punching_a + .client + .data() + .sym_to_cone_client + .try_direct_connect + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a + .client + .data() + .sym_to_cone_client + .punch_predicablely + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a.run().await.unwrap(); + hole_punching_c.run().await.unwrap(); + + hole_punching_a.client.run_immediately().await; + + wait_for_condition( + || async { + hole_punching_a + .client + .data() + .sym_to_cone_client + .udp_array + .read() + .await + .is_some() + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || async { + wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1)) + .await + .is_ok() + }, + Duration::from_secs(5), + ) + .await; + println!("{:?}", p_a.list_routes().await); + + wait_for_condition( + || async { + hole_punching_a + .client + .data() + .sym_to_cone_client + .udp_array + .read() + .await + .is_none() + }, + Duration::from_secs(10), + ) + .await; + } + + #[rstest::rstest] + #[tokio::test] + #[serial_test::serial(hole_punch)] + async fn hole_punching_symmetric_only_predict(#[values("true", "false")] is_inc: bool) { + let p_a = create_mock_peer_manager_with_mock_stun(if is_inc { + NatType::SymmetricEasyInc + } else { + NatType::SymmetricEasyDec + }) + .await; + let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); + + let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone()); + let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone()); + + hole_punching_a + .client + .data() + .sym_to_cone_client + .try_direct_connect + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a + .client + .data() + .sym_to_cone_client + .punch_randomly + .store(false, std::sync::atomic::Ordering::Relaxed); + + hole_punching_a.run().await.unwrap(); + hole_punching_c.run().await.unwrap(); + + let udps = if is_inc { + let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap()); + let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40194").await.unwrap()); + vec![udp1, udp2] + } else { + let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40141").await.unwrap()); + let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40100").await.unwrap()); + vec![udp1, udp2] + }; + // let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap()); + // let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40050").await.unwrap()); + + let counter = Arc::new(AtomicU32::new(0)); + + // all these sockets should receive hole punching packet + for udp in udps.iter().map(Arc::clone) { + let counter = counter.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + let (len, addr) = udp.recv_from(&mut buf).await.unwrap(); + println!( + "got predictable punch packet, {:?} {:?} {:?}", + len, + addr, + udp.local_addr() + ); + counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + }); + } + + hole_punching_a.client.run_immediately().await; + + let udp_len = udps.len(); + wait_for_condition( + || async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 }, + Duration::from_secs(30), + ) + .await; + } +} diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 886eb7a..c51529d 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -179,14 +179,16 @@ impl CommandHandler { async fn list_peers(&self) -> Result { let client = self.get_peer_manager_client().await?; let request = ListPeerRequest::default(); - let response = client.list_peer(BaseController {}, request).await?; + let response = client.list_peer(BaseController::default(), request).await?; Ok(response) } async fn list_routes(&self) -> Result { let client = self.get_peer_manager_client().await?; let request = ListRouteRequest::default(); - let response = client.list_route(BaseController {}, request).await?; + let response = client + .list_route(BaseController::default(), request) + .await?; Ok(response) } @@ -275,7 +277,7 @@ impl CommandHandler { let client = self.get_peer_manager_client().await?; let node_info = client - .show_node_info(BaseController {}, ShowNodeInfoRequest::default()) + .show_node_info(BaseController::default(), ShowNodeInfoRequest::default()) .await? .node_info .ok_or(anyhow::anyhow!("node info not found"))?; @@ -296,7 +298,9 @@ impl CommandHandler { async fn handle_route_dump(&self) -> Result<(), Error> { let client = self.get_peer_manager_client().await?; let request = DumpRouteRequest::default(); - let response = client.dump_route(BaseController {}, request).await?; + let response = client + .dump_route(BaseController::default(), request) + .await?; println!("response: {}", response.result); Ok(()) } @@ -305,7 +309,7 @@ impl CommandHandler { let client = self.get_peer_manager_client().await?; let request = ListForeignNetworkRequest::default(); let response = client - .list_foreign_network(BaseController {}, request) + .list_foreign_network(BaseController::default(), request) .await?; let network_map = response; if self.verbose { @@ -347,7 +351,7 @@ impl CommandHandler { let client = self.get_peer_manager_client().await?; let request = ListGlobalForeignNetworkRequest::default(); let response = client - .list_global_foreign_network(BaseController {}, request) + .list_global_foreign_network(BaseController::default(), request) .await?; if self.verbose { println!("{:#?}", response); @@ -383,7 +387,7 @@ impl CommandHandler { let mut items: Vec = vec![]; let client = self.get_peer_manager_client().await?; let node_info = client - .show_node_info(BaseController {}, ShowNodeInfoRequest::default()) + .show_node_info(BaseController::default(), ShowNodeInfoRequest::default()) .await? .node_info .ok_or(anyhow::anyhow!("node info not found"))?; @@ -451,7 +455,9 @@ impl CommandHandler { async fn handle_connector_list(&self) -> Result<(), Error> { let client = self.get_connector_manager_client().await?; let request = ListConnectorRequest::default(); - let response = client.list_connector(BaseController {}, request).await?; + let response = client + .list_connector(BaseController::default(), request) + .await?; println!("response: {:#?}", response); Ok(()) } @@ -515,7 +521,7 @@ async fn main() -> Result<(), Error> { Some(RouteSubCommand::Dump) => handler.handle_route_dump().await?, }, SubCommand::Stun => { - timeout(Duration::from_secs(5), async move { + timeout(Duration::from_secs(25), async move { let collector = StunInfoCollector::new_with_default_servers(); loop { let ret = collector.get_stun_info(); @@ -532,7 +538,10 @@ async fn main() -> Result<(), Error> { SubCommand::PeerCenter => { let peer_center_client = handler.get_peer_center_client().await?; let resp = peer_center_client - .get_global_peer_map(BaseController {}, GetGlobalPeerMapRequest::default()) + .get_global_peer_map( + BaseController::default(), + GetGlobalPeerMapRequest::default(), + ) .await?; #[derive(tabled::Tabled)] @@ -565,7 +574,10 @@ async fn main() -> Result<(), Error> { SubCommand::VpnPortal => { let vpn_portal_client = handler.get_vpn_portal_client().await?; let resp = vpn_portal_client - .get_vpn_portal_info(BaseController {}, GetVpnPortalInfoRequest::default()) + .get_vpn_portal_info( + BaseController::default(), + GetVpnPortalInfoRequest::default(), + ) .await? .vpn_portal_info .unwrap_or_default(); @@ -583,7 +595,7 @@ async fn main() -> Result<(), Error> { SubCommand::Node(sub_cmd) => { let client = handler.get_peer_manager_client().await?; let node_info = client - .show_node_info(BaseController {}, ShowNodeInfoRequest::default()) + .show_node_info(BaseController::default(), ShowNodeInfoRequest::default()) .await? .node_info .ok_or(anyhow::anyhow!("node info not found"))?; diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index ff0bd6b..e973e8a 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -161,7 +161,7 @@ impl Instance { DirectConnectorManager::new(global_ctx.clone(), peer_manager.clone()); direct_conn_manager.run(); - let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone()); + let udp_hole_puncher = UdpHolePunchConnector::new(peer_manager.clone()); let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone())); diff --git a/easytier/src/peer_center/instance.rs b/easytier/src/peer_center/instance.rs index d7bb0a5..0537c7b 100644 --- a/easytier/src/peer_center/instance.rs +++ b/easytier/src/peer_center/instance.rs @@ -230,7 +230,7 @@ impl PeerCenterInstance { let ret = client .get_global_peer_map( - BaseController {}, + BaseController::default(), GetGlobalPeerMapRequest { digest: ctx.job_ctx.global_peer_map_digest.load(), }, @@ -307,7 +307,7 @@ impl PeerCenterInstance { let ret = client .report_peers( - BaseController {}, + BaseController::default(), ReportPeersRequest { my_peer_id: my_node_id, peer_infos: Some(peers), diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index 2598738..6056c58 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -15,6 +15,8 @@ pub mod foreign_network_manager; pub mod encrypt; +pub mod peer_task; + #[cfg(test)] pub mod tests; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 6a4534d..7139661 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -1058,7 +1058,7 @@ mod tests { let ret = stub .say_hello( - RpcController {}, + RpcController::default(), SayHelloRequest { name: "abc".to_string(), }, diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index fcd73d0..adba786 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -539,7 +539,7 @@ impl RouteTable { fn get_nat_type(&self, peer_id: PeerId) -> Option { self.peer_infos .get(&peer_id) - .map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap()) + .map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap_or_default()) } fn build_peer_graph_from_synced_info( @@ -1322,7 +1322,7 @@ impl PeerRouteServiceImpl { self.global_ctx.get_network_name(), ); - let mut ctrl = BaseController {}; + let mut ctrl = BaseController::default(); ctrl.set_timeout_ms(3000); let ret = rpc_stub .sync_route_info( diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index c34d490..4f2fafe 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -224,7 +224,10 @@ pub mod tests { let msg = random_string(8192); let ret = stub - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); @@ -233,7 +236,10 @@ pub mod tests { let msg = random_string(10); let ret = stub - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); @@ -281,7 +287,10 @@ pub mod tests { ); let ret = stub - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); assert_eq!(ret.greeting, format!("Hello {}!", msg)); @@ -289,14 +298,20 @@ pub mod tests { // call again let msg = random_string(16 * 1024); let ret = stub - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); assert_eq!(ret.greeting, format!("Hello {}!", msg)); let msg = random_string(16 * 1024); let ret = stub - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); assert_eq!(ret.greeting, format!("Hello {}!", msg)); @@ -340,13 +355,19 @@ pub mod tests { let msg = random_string(16 * 1024); let ret = stub1 - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await .unwrap(); assert_eq!(ret.greeting, format!("Hello {}!", msg)); let ret = stub2 - .say_hello(RpcController {}, SayHelloRequest { name: msg.clone() }) + .say_hello( + RpcController::default(), + SayHelloRequest { name: msg.clone() }, + ) .await; assert!(ret.is_err() && ret.unwrap_err().to_string().contains("Timeout")); } diff --git a/easytier/src/peers/peer_task.rs b/easytier/src/peers/peer_task.rs new file mode 100644 index 0000000..0fcb89e --- /dev/null +++ b/easytier/src/peers/peer_task.rs @@ -0,0 +1,138 @@ +use std::result::Result; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use dashmap::DashMap; +use tokio::select; +use tokio::sync::Notify; +use tokio::task::JoinHandle; + +use crate::common::scoped_task::ScopedTask; +use anyhow::Error; + +use super::peer_manager::PeerManager; + +#[async_trait] +pub trait PeerTaskLauncher: Send + Sync + Clone + 'static { + type Data; + type CollectPeerItem; + type TaskRet; + + fn new_data(&self, peer_mgr: Arc) -> Self::Data; + async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec; + async fn launch_task( + &self, + data: &Self::Data, + item: Self::CollectPeerItem, + ) -> JoinHandle>; + + async fn all_task_done(&self, _data: &Self::Data) {} + + fn loop_interval_ms(&self) -> u64 { + 5000 + } +} + +pub struct PeerTaskManager { + launcher: Launcher, + peer_mgr: Arc, + main_loop_task: Mutex>>, + run_signal: Arc, + data: Launcher::Data, +} + +impl PeerTaskManager +where + D: Send + Sync + Clone + 'static, + C: std::fmt::Debug + Send + Sync + Clone + core::hash::Hash + Eq + 'static, + T: Send + 'static, + L: PeerTaskLauncher + 'static, +{ + pub fn new(launcher: L, peer_mgr: Arc) -> Self { + let data = launcher.new_data(peer_mgr.clone()); + Self { + launcher, + peer_mgr, + main_loop_task: Mutex::new(None), + run_signal: Arc::new(Notify::new()), + data, + } + } + + pub fn start(&self) { + let task = tokio::spawn(Self::main_loop( + self.launcher.clone(), + self.data.clone(), + self.run_signal.clone(), + )) + .into(); + self.main_loop_task.lock().unwrap().replace(task); + } + + async fn main_loop(launcher: L, data: D, signal: Arc) { + let peer_task_map = Arc::new(DashMap::>>::new()); + + loop { + let peers_to_connect = launcher.collect_peers_need_task(&data).await; + + // remove task not in peers_to_connect + let mut to_remove = vec![]; + for item in peer_task_map.iter() { + if !peers_to_connect.contains(item.key()) || item.value().is_finished() { + to_remove.push(item.key().clone()); + } + } + + tracing::debug!( + ?peers_to_connect, + ?to_remove, + "got peers to connect and remove" + ); + + for key in to_remove { + if let Some((_, task)) = peer_task_map.remove(&key) { + task.abort(); + match task.await { + Ok(Ok(_)) => {} + Ok(Err(task_ret)) => { + tracing::error!(?task_ret, "hole punching task failed"); + } + Err(e) => { + tracing::error!(?e, "hole punching task aborted"); + } + } + } + } + + if !peers_to_connect.is_empty() { + for item in peers_to_connect { + if peer_task_map.contains_key(&item) { + continue; + } + + tracing::debug!(?item, "launch hole punching task"); + peer_task_map + .insert(item.clone(), launcher.launch_task(&data, item).await.into()); + } + } else if peer_task_map.is_empty() { + tracing::debug!("all task done"); + launcher.all_task_done(&data).await; + } + + select! { + _ = tokio::time::sleep(std::time::Duration::from_millis( + launcher.loop_interval_ms(), + )) => {}, + _ = signal.notified() => {} + } + } + } + + pub async fn run_immediately(&self) { + self.run_signal.notify_one(); + } + + pub fn data(&self) -> D { + self.data.clone() + } +} diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index 7caeb60..bd9901d 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -42,6 +42,8 @@ message RpcPacket { int32 trace_id = 9; } +message Void {} + message UUID { uint64 high = 1; uint64 low = 2; @@ -57,6 +59,8 @@ enum NatType { PortRestricted = 5; Symmetric = 6; SymUdpFirewall = 7; + SymmetricEasyInc = 8; + SymmetricEasyDec = 9; } message Ipv4Addr { uint32 addr = 1; } diff --git a/easytier/src/proto/peer_rpc.proto b/easytier/src/proto/peer_rpc.proto index ac47b7f..8fffa37 100644 --- a/easytier/src/proto/peer_rpc.proto +++ b/easytier/src/proto/peer_rpc.proto @@ -93,27 +93,78 @@ service DirectConnectorRpc { rpc GetIpList(GetIpListRequest) returns (GetIpListResponse); } -message TryPunchHoleRequest { common.SocketAddr local_mapped_addr = 1; } - -message TryPunchHoleResponse { common.SocketAddr remote_mapped_addr = 1; } - -message TryPunchSymmetricRequest { - common.SocketAddr listener_addr = 1; - uint32 port = 2; - repeated common.Ipv4Addr public_ips = 3; - uint32 min_port = 4; - uint32 max_port = 5; - uint32 transaction_id = 6; - uint32 round = 7; - uint32 last_port_index = 8; +message SelectPunchListenerRequest { + bool force_new = 1; } -message TryPunchSymmetricResponse { uint32 last_port_index = 1; } +message SelectPunchListenerResponse { + common.SocketAddr listener_mapped_addr = 1; +} + +message SendPunchPacketConeRequest { + common.SocketAddr listener_mapped_addr = 1; + common.SocketAddr dest_addr = 2; + uint32 transaction_id = 3; + // send this many packets in a batch + uint32 packet_count_per_batch = 4; + // send total this batch count, total packet count = packet_batch_size * packet_batch_count + uint32 packet_batch_count = 5; + // interval between each batch + uint32 packet_interval_ms = 6; +} + +message SendPunchPacketHardSymRequest { + common.SocketAddr listener_mapped_addr = 1; + + repeated common.Ipv4Addr public_ips = 2; + uint32 transaction_id = 3; + uint32 port_index = 4; + uint32 round = 5; +} + +message SendPunchPacketHardSymResponse { uint32 next_port_index = 1; } + +message SendPunchPacketEasySymRequest { + common.SocketAddr listener_mapped_addr = 1; + repeated common.Ipv4Addr public_ips = 2; + uint32 transaction_id = 3; + + uint32 base_port_num = 4; + uint32 max_port_num = 5; + bool is_incremental = 6; +} + +message SendPunchPacketBothEasySymRequest { + uint32 udp_socket_count = 1; + common.Ipv4Addr public_ip = 2; + uint32 transaction_id = 3; + + uint32 dst_port_num = 4; + uint32 wait_time_ms = 5; +} + +message SendPunchPacketBothEasySymResponse { + // is doing punch with other peer + bool is_busy = 1; + common.SocketAddr base_mapped_addr = 2; +} service UdpHolePunchRpc { - rpc TryPunchHole(TryPunchHoleRequest) returns (TryPunchHoleResponse); - rpc TryPunchSymmetric(TryPunchSymmetricRequest) - returns (TryPunchSymmetricResponse); + rpc SelectPunchListener(SelectPunchListenerRequest) + returns (SelectPunchListenerResponse); + + // send packet to one remote_addr, used by nat1-3 to nat1-3 + rpc SendPunchPacketCone(SendPunchPacketConeRequest) returns (common.Void); + + // send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3 + rpc SendPunchPacketHardSym(SendPunchPacketHardSymRequest) + returns (SendPunchPacketHardSymResponse); + rpc SendPunchPacketEasySym(SendPunchPacketEasySymRequest) + returns (common.Void); + + // nat4 to nat4 (both predictably) + rpc SendPunchPacketBothEasySym(SendPunchPacketBothEasySymRequest) + returns (SendPunchPacketBothEasySymResponse); } message DirectConnectedPeerInfo { int32 latency_ms = 1; } diff --git a/easytier/src/proto/rpc_impl/server.rs b/easytier/src/proto/rpc_impl/server.rs index d1db95e..372cd2a 100644 --- a/easytier/src/proto/rpc_impl/server.rs +++ b/easytier/src/proto/rpc_impl/server.rs @@ -146,7 +146,7 @@ impl Server { async fn handle_rpc_request(packet: RpcPacket, reg: Arc) -> Result { let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?; let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64); - let ctrl = RpcController {}; + let ctrl = RpcController::default(); Ok(timeout( timeout_duration, reg.call_method( diff --git a/easytier/src/proto/rpc_types/controller.rs b/easytier/src/proto/rpc_types/controller.rs index 60b97d8..900fa2a 100644 --- a/easytier/src/proto/rpc_types/controller.rs +++ b/easytier/src/proto/rpc_types/controller.rs @@ -13,6 +13,34 @@ pub trait Controller: Send + Sync + 'static { } #[derive(Debug)] -pub struct BaseController {} +pub struct BaseController { + pub timeout_ms: i32, + pub trace_id: i32, +} -impl Controller for BaseController {} +impl Controller for BaseController { + fn timeout_ms(&self) -> i32 { + self.timeout_ms + } + + fn set_timeout_ms(&mut self, timeout_ms: i32) { + self.timeout_ms = timeout_ms; + } + + fn set_trace_id(&mut self, trace_id: i32) { + self.trace_id = trace_id; + } + + fn trace_id(&self) -> i32 { + self.trace_id + } +} + +impl Default for BaseController { + fn default() -> Self { + Self { + timeout_ms: 5000, + trace_id: 0, + } + } +} diff --git a/easytier/src/proto/tests.rs b/easytier/src/proto/tests.rs index dba760b..7fb978b 100644 --- a/easytier/src/proto/tests.rs +++ b/easytier/src/proto/tests.rs @@ -121,14 +121,14 @@ async fn rpc_basic_test() { // small size req and resp - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let input = SayHelloRequest { name: "world".to_string(), }; let ret = out.say_hello(ctrl, input).await; assert_eq!(ret.unwrap().greeting, "Hello world!"); - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let input = SayGoodbyeRequest { name: "world".to_string(), }; @@ -136,7 +136,7 @@ async fn rpc_basic_test() { assert_eq!(ret.unwrap().greeting, "Goodbye, world!"); // large size req and resp - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let name = random_string(20 * 1024 * 1024); let input = SayGoodbyeRequest { name: name.clone() }; let ret = out.say_goodbye(ctrl, input).await; @@ -160,7 +160,7 @@ async fn rpc_timeout_test() { .client .scoped_client::>(1, 1, "test".to_string()); - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let input = SayHelloRequest { name: "world".to_string(), }; @@ -199,7 +199,7 @@ async fn standalone_rpc_test() { .await .unwrap(); - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let input = SayHelloRequest { name: "world".to_string(), }; @@ -211,7 +211,7 @@ async fn standalone_rpc_test() { .await .unwrap(); - let ctrl = RpcController {}; + let ctrl = RpcController::default(); let input = SayGoodbyeRequest { name: "world".to_string(), }; diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index 55d27cc..eb32131 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -94,7 +94,7 @@ pub trait Tunnel: Send { #[auto_impl::auto_impl(Arc)] pub trait TunnelConnCounter: 'static + Send + Sync + Debug { - fn get(&self) -> u32; + fn get(&self) -> Option; } #[derive(Debug, Clone, Copy, PartialEq)] @@ -114,8 +114,8 @@ pub trait TunnelListener: Send { #[derive(Debug)] struct FakeTunnelConnCounter {} impl TunnelConnCounter for FakeTunnelConnCounter { - fn get(&self) -> u32 { - 0 + fn get(&self) -> Option { + None } } Arc::new(Box::new(FakeTunnelConnCounter {})) diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs index fc9d5d6..2a18b23 100644 --- a/easytier/src/tunnel/tcp.rs +++ b/easytier/src/tunnel/tcp.rs @@ -43,6 +43,10 @@ impl TunnelListener for TcpTunnelListener { setup_sokcet2(&socket2_socket, &addr)?; let socket = TcpSocket::from_std_stream(socket2_socket.into()); + if let Err(e) = socket.set_nodelay(true) { + tracing::warn!(?e, "set_nodelay fail in listen"); + } + self.addr .set_port(Some(socket.local_addr()?.port())) .unwrap(); @@ -54,7 +58,11 @@ impl TunnelListener for TcpTunnelListener { async fn accept(&mut self) -> Result, super::TunnelError> { let listener = self.listener.as_ref().unwrap(); let (stream, _) = listener.accept().await?; - stream.set_nodelay(true).unwrap(); + + if let Err(e) = stream.set_nodelay(true) { + tracing::warn!(?e, "set_nodelay fail in accept"); + } + let info = TunnelInfo { tunnel_type: "tcp".to_owned(), local_addr: Some(self.local_url().into()), @@ -80,7 +88,9 @@ fn get_tunnel_with_tcp_stream( stream: TcpStream, remote_url: url::Url, ) -> Result, super::TunnelError> { - stream.set_nodelay(true).unwrap(); + if let Err(e) = stream.set_nodelay(true) { + tracing::warn!(?e, "set_nodelay fail in get_tunnel_with_tcp_stream"); + } let info = TunnelInfo { tunnel_type: "tcp".to_owned(), diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index df96ec2..5ba715a 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -1,4 +1,7 @@ -use std::{fmt::Debug, sync::Arc}; +use std::{ + fmt::Debug, + sync::{Arc, Weak}, +}; use async_trait::async_trait; use bytes::BytesMut; @@ -445,25 +448,25 @@ impl TunnelListener for UdpTunnelListener { fn get_conn_counter(&self) -> Arc> { struct UdpTunnelConnCounter { - sock_map: Arc>, + sock_map: Weak>, + } + + impl TunnelConnCounter for UdpTunnelConnCounter { + fn get(&self) -> Option { + self.sock_map.upgrade().map(|x| x.len() as u32) + } } impl Debug for UdpTunnelConnCounter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("UdpTunnelConnCounter") - .field("sock_map_len", &self.sock_map.len()) + .field("sock_map_len", &self.get()) .finish() } } - impl TunnelConnCounter for UdpTunnelConnCounter { - fn get(&self) -> u32 { - self.sock_map.len() as u32 - } - } - Arc::new(Box::new(UdpTunnelConnCounter { - sock_map: self.data.sock_map.clone(), + sock_map: Arc::downgrade(&self.data.sock_map.clone()), })) } } @@ -942,14 +945,22 @@ mod tests { listener.listen().await.unwrap(); let c1 = listener.accept().await.unwrap(); - assert_eq!(conn_counter.get(), 1); + assert_eq!(conn_counter.get(), Some(1)); let c2 = listener.accept().await.unwrap(); - assert_eq!(conn_counter.get(), 2); + assert_eq!(conn_counter.get(), Some(2)); drop(c2); - wait_for_condition(|| async { conn_counter.get() == 1 }, Duration::from_secs(1)).await; + wait_for_condition( + || async { conn_counter.get() == Some(1) }, + Duration::from_secs(1), + ) + .await; drop(c1); - wait_for_condition(|| async { conn_counter.get() == 0 }, Duration::from_secs(1)).await; + wait_for_condition( + || async { conn_counter.get().unwrap_or(0) == 0 }, + Duration::from_secs(1), + ) + .await; } }