From e4be86cf9264fb8abd19338a695aa9da4c6f10df Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Wed, 3 Apr 2024 21:46:52 +0800 Subject: [PATCH] allow specify bind dev for tunnels. also fix bugs #46) 1. fix wireguard / udp tunnel stack overflow on win. 2. custom panic handler to save panic stack. 3. fix iface filter on windows and linux. 4. add scheme black list to direct connector --- Cargo.toml | 5 +- src/arch/windows.rs | 5 +- src/common/constants.rs | 4 - src/common/network.rs | 78 +++++++++++------ src/connector/direct.rs | 170 +++++++++++++++++++++++++++----------- src/easytier-core.rs | 14 +++- src/peers/peer_conn.rs | 8 +- src/tunnels/common.rs | 20 ++++- src/tunnels/mod.rs | 32 +++++++ src/tunnels/udp_tunnel.rs | 56 +++++++++++-- src/tunnels/wireguard.rs | 32 ++++--- 11 files changed, 320 insertions(+), 104 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1b9a438..e2c4572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,7 @@ anyhow = "1.0" tarpc = { version = "0.32", features = ["tokio1", "serde1"] } url = { version = "2.5", features = ["serde"] } +percent-encoding = "2.3.1" # for tun packet byteorder = "1.5.0" @@ -144,8 +145,8 @@ serial_test = "3.0.0" rstest = "0.18.2" [profile.dev] -panic = "abort" +panic = "unwind" [profile.release] -panic = "abort" +panic = "unwind" lto = true diff --git a/src/arch/windows.rs b/src/arch/windows.rs index 46b4f3e..2554cde 100644 --- a/src/arch/windows.rs +++ b/src/arch/windows.rs @@ -19,8 +19,6 @@ use windows_sys::{ }, }; -use crate::tunnels::common::get_interface_name_by_ip; - pub fn disable_connection_reset(socket: &S) -> io::Result<()> { let handle = socket.as_raw_socket() as SOCKET; @@ -132,13 +130,14 @@ pub fn set_ip_unicast_if( pub fn setup_socket_for_win( socket: &S, bind_addr: &SocketAddr, + bind_dev: Option, is_udp: bool, ) -> io::Result<()> { if is_udp { disable_connection_reset(socket)?; } - if let Some(iface) = get_interface_name_by_ip(&bind_addr.ip()) { + if let Some(iface) = bind_dev { set_ip_unicast_if(socket, bind_addr, iface.as_str())?; } diff --git a/src/common/constants.rs b/src/common/constants.rs index 2f36d4a..955ec3d 100644 --- a/src/common/constants.rs +++ b/src/common/constants.rs @@ -1,7 +1,3 @@ -pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1; -pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 60; -pub const DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC: u64 = 60; - macro_rules! define_global_var { ($name:ident, $type:ty, $init:expr) => { pub static $name: once_cell::sync::Lazy> = diff --git a/src/common/network.rs b/src/common/network.rs index 69ea760..1ed248c 100644 --- a/src/common/network.rs +++ b/src/common/network.rs @@ -7,7 +7,9 @@ use tokio::{ task::JoinSet, }; -use super::{constants::DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, netns::NetNS}; +use super::netns::NetNS; + +pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60; struct InterfaceFilter { iface: NetworkInterface, @@ -15,33 +17,37 @@ struct InterfaceFilter { #[cfg(target_os = "linux")] impl InterfaceFilter { - async fn is_iface_bridge(&self) -> bool { - let path = format!("/sys/class/net/{}/bridge", self.iface.name); + async fn is_tun_tap_device(&self) -> bool { + let path = format!("/sys/class/net/{}/tun_flags", self.iface.name); tokio::fs::metadata(&path).await.is_ok() } - async fn is_iface_phsical(&self) -> bool { - let path = format!("/sys/class/net/{}/device", self.iface.name); - tokio::fs::metadata(&path).await.is_ok() + async fn has_valid_ip(&self) -> bool { + self.iface + .ips + .iter() + .map(|ip| ip.ip()) + .any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast()) } async fn filter_iface(&self) -> bool { tracing::trace!( - "filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_bridge: {}, is_physical: {}", + "filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_tun: {}, has_valid_ip: {}", self.iface, self.iface.is_point_to_point(), self.iface.is_loopback(), self.iface.is_up(), self.iface.is_lower_up(), - self.is_iface_bridge().await, - self.is_iface_phsical().await, + self.is_tun_tap_device().await, + self.has_valid_ip().await ); !self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up() && self.iface.is_lower_up() - && (self.is_iface_bridge().await || self.is_iface_phsical().await) + && !self.is_tun_tap_device().await + && self.has_valid_ip().await } } @@ -85,7 +91,22 @@ impl InterfaceFilter { #[cfg(target_os = "windows")] impl InterfaceFilter { async fn filter_iface(&self) -> bool { - !self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up() + tracing::debug!( + "iface_name: {:?}, p2p: {:?}, is_up: {:?}, iface: {:?}", + self.iface.name, + self.iface.is_point_to_point(), + self.iface.is_up(), + self.iface + ); + !self.iface.is_point_to_point() + && !self.iface.is_loopback() + && self + .iface + .ips + .iter() + .map(|ip| ip.ip()) + .any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast()) + && self.iface.mac.map(|mac| !mac.is_zero()).unwrap_or(false) } } @@ -143,10 +164,8 @@ impl IPCollector { loop { let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await; *cached_ip_list.write().await = ip_addrs; - tokio::time::sleep(std::time::Duration::from_secs( - DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, - )) - .await; + tokio::time::sleep(std::time::Duration::from_secs(CACHED_IP_LIST_TIMEOUT_SEC)) + .await; } }); } @@ -154,6 +173,25 @@ impl IPCollector { return self.cached_ip_list.read().await.deref().clone(); } + pub async fn collect_interfaces(net_ns: NetNS) -> Vec { + let _g = net_ns.guard(); + let ifaces = pnet::datalink::interfaces(); + let mut ret = vec![]; + for iface in ifaces { + let f = InterfaceFilter { + iface: iface.clone(), + }; + + if !f.filter_iface().await { + continue; + } + + ret.push(iface); + } + + ret + } + #[tracing::instrument(skip(net_ns))] async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse { let mut ret = crate::rpc::peer::GetIpListResponse::new(); @@ -170,17 +208,9 @@ impl IPCollector { } } + let ifaces = Self::collect_interfaces(net_ns.clone()).await; let _g = net_ns.guard(); - let ifaces = pnet::datalink::interfaces(); for iface in ifaces { - let f = InterfaceFilter { - iface: iface.clone(), - }; - - if !f.filter_iface().await { - continue; - } - for ip in iface.ips { let ip: std::net::IpAddr = ip.ip(); if ip.is_loopback() || ip.is_multicast() { diff --git a/src/connector/direct.rs b/src/connector/direct.rs index f92ae73..c210b02 100644 --- a/src/connector/direct.rs +++ b/src/connector/direct.rs @@ -3,12 +3,7 @@ use std::sync::Arc; use crate::{ - common::{ - constants::{self, DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC}, - error::Error, - global_ctx::ArcGlobalCtx, - PeerId, - }, + common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager}, }; @@ -18,6 +13,9 @@ use tracing::Instrument; use super::create_connector_by_url; +pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1; +pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300; + #[tarpc::service] pub trait DirectConnectorRpc { async fn get_ip_list() -> GetIpListResponse; @@ -76,10 +74,25 @@ impl DirectConnectorManagerRpcServer { #[derive(Hash, Eq, PartialEq, Clone)] struct DstBlackListItem(PeerId, String); +#[derive(Hash, Eq, PartialEq, Clone)] +struct DstSchemeBlackListItem(PeerId, String); + struct DirectConnectorManagerData { global_ctx: ArcGlobalCtx, peer_manager: Arc, dst_blacklist: timedmap::TimedMap, + dst_sceme_blacklist: timedmap::TimedMap, +} + +impl DirectConnectorManagerData { + pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc) -> Self { + Self { + global_ctx, + peer_manager, + dst_blacklist: timedmap::TimedMap::new(), + dst_sceme_blacklist: timedmap::TimedMap::new(), + } + } } impl std::fmt::Debug for DirectConnectorManagerData { @@ -101,11 +114,7 @@ impl DirectConnectorManager { pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc) -> Self { Self { global_ctx: global_ctx.clone(), - data: Arc::new(DirectConnectorManagerData { - global_ctx, - peer_manager, - dst_blacklist: timedmap::TimedMap::new(), - }), + data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)), tasks: JoinSet::new(), } } @@ -117,7 +126,7 @@ impl DirectConnectorManager { pub fn run_as_server(&mut self) { self.data.peer_manager.get_peer_rpc_mgr().run_service( - constants::DIRECT_CONNECTOR_SERVICE_ID, + DIRECT_CONNECTOR_SERVICE_ID, DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(), ); } @@ -193,7 +202,7 @@ impl DirectConnectorManager { data: Arc, dst_peer_id: PeerId, addr: String, - ) { + ) -> Result<(), Error> { let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await; if let Err(e) = ret { if !matches!(e, Error::UrlInBlacklist) { @@ -208,47 +217,36 @@ impl DirectConnectorManager { std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC), ); } + return Err(e); } else { log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id); + return Ok(()); } } #[tracing::instrument] - async fn do_try_direct_connect( + async fn do_try_direct_connect_internal( data: Arc, dst_peer_id: PeerId, + ip_list: GetIpListResponse, ) -> Result<(), Error> { - let peer_manager = data.peer_manager.clone(); - // check if we have direct connection with dst_peer_id - if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await { - // currently if we have any type of direct connection (udp or tcp), we will not try to connect - if !c.is_empty() { - return Ok(()); - } - } - - log::trace!("try direct connect to peer: {}", dst_peer_id); - - let ip_list = peer_manager - .get_peer_rpc_mgr() - .do_client_rpc_scoped(1, dst_peer_id, |c| async { - let client = - DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn(); - let ip_list = client.get_ip_list(tarpc::context::current()).await; - tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list"); - ip_list - }) - .await?; - let available_listeners = ip_list .listeners .iter() .filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None }) + .filter(|l| l.port().is_some()) + .filter(|l| { + !data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem( + dst_peer_id.clone(), + l.scheme().to_string(), + )) + }) .collect::>(); - let mut listener = available_listeners - .get(0) - .ok_or(anyhow::anyhow!("peer {} have no listener", dst_peer_id))?; + let mut listener = available_listeners.get(0).ok_or(anyhow::anyhow!( + "peer {} have no valid listener", + dst_peer_id + ))?; // if have default listener, use it first listener = available_listeners @@ -283,30 +281,77 @@ impl DirectConnectorManager { addr, )); + let mut has_succ = false; while let Some(ret) = tasks.join_next().await { if let Err(e) = ret { log::error!("join direct connect task failed: {:?}", e); + } else if let Ok(Ok(_)) = ret { + has_succ = true; } } + if !has_succ { + data.dst_sceme_blacklist.insert( + DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()), + (), + std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC), + ); + } + Ok(()) } + + #[tracing::instrument] + async fn do_try_direct_connect( + data: Arc, + dst_peer_id: PeerId, + ) -> Result<(), Error> { + let peer_manager = data.peer_manager.clone(); + // check if we have direct connection with dst_peer_id + if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await { + // currently if we have any type of direct connection (udp or tcp), we will not try to connect + if !c.is_empty() { + return Ok(()); + } + } + + log::trace!("try direct connect to peer: {}", dst_peer_id); + + let ip_list = peer_manager + .get_peer_rpc_mgr() + .do_client_rpc_scoped(1, dst_peer_id, |c| async { + let client = + DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn(); + let ip_list = client.get_ip_list(tarpc::context::current()).await; + tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list"); + ip_list + }) + .await?; + + Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await + } } #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{ - connector::direct::DirectConnectorManager, + connector::direct::{ + DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem, + DstSchemeBlackListItem, + }, instance::listeners::ListenerManager, peers::tests::{ connect_peer_manager, create_mock_peer_manager, wait_route_appear, wait_route_appear_with_cost, }, - tunnels::tcp_tunnel::TcpTunnelListener, + rpc::peer::GetIpListResponse, }; + #[rstest::rstest] #[tokio::test] - async fn direct_connector_basic_test() { + async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) { let p_a = create_mock_peer_manager().await; let p_b = create_mock_peer_manager().await; let p_c = create_mock_peer_manager().await; @@ -321,14 +366,14 @@ mod tests { dm_a.run_as_client(); dm_c.run_as_server(); + let port = if proto == "wg" { 11040 } else { 11041 }; + p_c.get_global_ctx() + .config + .set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port) + .parse() + .unwrap()]); let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone()); - - lis_c - .add_listener(TcpTunnelListener::new( - "tcp://0.0.0.0:11040".parse().unwrap(), - )) - .await - .unwrap(); + lis_c.prepare_listeners().await.unwrap(); lis_c.run().await.unwrap(); @@ -336,4 +381,31 @@ mod tests { .await .unwrap(); } + + #[tokio::test] + async fn direct_connector_scheme_blacklist() { + let p_a = create_mock_peer_manager().await; + let data = Arc::new(DirectConnectorManagerData::new( + p_a.get_global_ctx(), + p_a.clone(), + )); + let mut ip_list = GetIpListResponse::new(); + ip_list + .listeners + .push("tcp://127.0.0.1:10222".parse().unwrap()); + + ip_list.interface_ipv4s.push("127.0.0.1".to_string()); + + DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone()) + .await + .unwrap(); + + assert!(data + .dst_sceme_blacklist + .contains(&DstSchemeBlackListItem(1, "tcp".into()))); + + assert!(data + .dst_blacklist + .contains(&DstBlackListItem(1, ip_list.listeners[0].to_string()))); + } } diff --git a/src/easytier-core.rs b/src/easytier-core.rs index 500081c..cd65c7a 100644 --- a/src/easytier-core.rs +++ b/src/easytier-core.rs @@ -3,7 +3,7 @@ #[cfg(test)] mod tests; -use std::net::SocketAddr; +use std::{backtrace, io::Write as _, net::SocketAddr}; use anyhow::Context; use clap::Parser; @@ -318,9 +318,21 @@ fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String { ) } +fn setup_panic_handler() { + std::panic::set_hook(Box::new(|info| { + let backtrace = backtrace::Backtrace::force_capture(); + println!("panic occurred: {:?}", info); + let _ = std::fs::File::create("easytier-panic.log") + .and_then(|mut f| f.write_all(format!("{:?}\n{:#?}", info, backtrace).as_bytes())); + std::process::exit(1); + })); +} + #[tokio::main(flavor = "current_thread")] #[tracing::instrument] pub async fn main() { + setup_panic_handler(); + let cli = Cli::parse(); tracing::info!(cli = ?cli, "cli args parsed"); diff --git a/src/peers/peer_conn.rs b/src/peers/peer_conn.rs index c9373a4..1110a72 100644 --- a/src/peers/peer_conn.rs +++ b/src/peers/peer_conn.rs @@ -25,7 +25,7 @@ use crate::{ PeerId, }, define_tunnel_filter_chain, - peers::packet::{ArchivedPacketType, CtrlPacketPayload}, + peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType}, rpc::{PeerConnInfo, PeerConnStats}, tunnels::{ stats::{Throughput, WindowLatency}, @@ -52,6 +52,12 @@ macro_rules! wait_response { let $out_var; let rsp_bytes = Packet::decode(&rsp_vec); + if rsp_bytes.packet_type != PacketType::HandShake { + tracing::error!("unexpected packet type: {:?}", rsp_bytes); + return Err(TunnelError::WaitRespError( + "unexpected packet type".to_owned(), + )); + } let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes); match &resp_payload { $pattern => $out_var = $value, diff --git a/src/tunnels/common.rs b/src/tunnels/common.rs index 512c2f7..ed38704 100644 --- a/src/tunnels/common.rs +++ b/src/tunnels/common.rs @@ -275,14 +275,15 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option { None } -pub(crate) fn setup_sokcet2( +pub(crate) fn setup_sokcet2_ext( socket2_socket: &socket2::Socket, bind_addr: &SocketAddr, + bind_dev: Option, ) -> Result<(), TunnelError> { #[cfg(target_os = "windows")] { let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM); - crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, is_udp)?; + crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?; } socket2_socket.set_nonblocking(true)?; @@ -299,7 +300,7 @@ pub(crate) fn setup_sokcet2( // linux/mac does not use interface of bind_addr to send packet, so we need to bind device // win can handle this with bind correctly #[cfg(any(target_os = "ios", target_os = "macos"))] - if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) { + if let Some(dev_name) = bind_dev { // use IP_BOUND_IF to bind device unsafe { let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8); @@ -310,7 +311,7 @@ pub(crate) fn setup_sokcet2( } #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) { + if let Some(dev_name) = bind_dev { tracing::trace!(dev_name = ?dev_name, "bind device"); socket2_socket.bind_device(Some(dev_name.as_bytes()))?; } @@ -318,6 +319,17 @@ pub(crate) fn setup_sokcet2( Ok(()) } +pub(crate) fn setup_sokcet2( + socket2_socket: &socket2::Socket, + bind_addr: &SocketAddr, +) -> Result<(), TunnelError> { + setup_sokcet2_ext( + socket2_socket, + bind_addr, + super::common::get_interface_name_by_ip(&bind_addr.ip()), + ) +} + pub mod tests { use std::time::Instant; diff --git a/src/tunnels/mod.rs b/src/tunnels/mod.rs index da31a6f..6ad25e1 100644 --- a/src/tunnels/mod.rs +++ b/src/tunnels/mod.rs @@ -158,3 +158,35 @@ impl FromUrl for uuid::Uuid { Ok(o) } } + +pub struct TunnelUrl { + inner: url::Url, +} + +impl From for TunnelUrl { + fn from(url: url::Url) -> Self { + TunnelUrl { inner: url } + } +} + +impl From for url::Url { + fn from(url: TunnelUrl) -> Self { + url.into_inner() + } +} + +impl TunnelUrl { + pub fn into_inner(self) -> url::Url { + self.inner + } + + pub fn bind_dev(&self) -> Option { + self.inner.path().strip_prefix("/").and_then(|s| { + if s.is_empty() { + None + } else { + Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap()) + } + }) + } +} diff --git a/src/tunnels/udp_tunnel.rs b/src/tunnels/udp_tunnel.rs index 98d32af..2656b85 100644 --- a/src/tunnels/udp_tunnel.rs +++ b/src/tunnels/udp_tunnel.rs @@ -23,9 +23,9 @@ use crate::{ use super::{ codec::BytesCodec, - common::{setup_sokcet2, FramedTunnel, TunnelWithCustomInfo}, + common::{setup_sokcet2, setup_sokcet2_ext, FramedTunnel, TunnelWithCustomInfo}, ring_tunnel::create_ring_tunnel_pair, - DatagramSink, DatagramStream, Tunnel, TunnelListener, + DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl, }; pub const UDP_DATA_MTU: usize = 65000; @@ -323,7 +323,14 @@ impl TunnelListener for UdpTunnelListener { socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; - setup_sokcet2(&socket2_socket, &addr)?; + + let tunnel_url: TunnelUrl = self.addr.clone().into(); + if let Some(bind_dev) = tunnel_url.bind_dev() { + setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + } else { + setup_sokcet2(&socket2_socket, &addr)?; + } + self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); let socket = self.socket.as_ref().unwrap().clone(); @@ -335,7 +342,7 @@ impl TunnelListener for UdpTunnelListener { async move { loop { let mut buf = BytesMut::new(); - buf.resize(2500, 0); + buf.resize(UDP_DATA_MTU, 0); let (_size, addr) = socket.recv_from(&mut buf).await.unwrap(); let _ = buf.split_off(_size); log::trace!( @@ -597,7 +604,16 @@ mod tests { use rand::Rng; use tokio::time::timeout; - use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}; + use crate::{ + common::global_ctx::tests::get_mock_global_ctx, + tunnels::{ + check_scheme_and_get_socket_addr, + common::{ + get_interface_name_by_ip, setup_sokcet2_ext, + tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}, + }, + }, + }; use super::*; @@ -723,4 +739,34 @@ mod tests { let a_udp_packet = rkyv_util::decode_from_bytes::(&b).unwrap(); println!("{:?}, {:?}", udp_packet, a_udp_packet); } + + #[tokio::test] + async fn bind_multi_ip_to_same_dev() { + let global_ctx = get_mock_global_ctx(); + let ips = global_ctx + .get_ip_collector() + .collect_ip_addrs() + .await + .interface_ipv4s; + if ips.is_empty() { + return; + } + let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap()); + + for ip in ips { + println!("bind to ip: {:?}, {:?}", ip, bind_dev); + let addr = check_scheme_and_get_socket_addr::( + &format!("udp://{}:11111", ip).parse().unwrap(), + "udp", + ) + .unwrap(); + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .unwrap(); + setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap(); + } + } } diff --git a/src/tunnels/wireguard.rs b/src/tunnels/wireguard.rs index d0db520..11d750b 100644 --- a/src/tunnels/wireguard.rs +++ b/src/tunnels/wireguard.rs @@ -26,8 +26,10 @@ use crate::{ }; use super::{ - check_scheme_and_get_socket_addr, common::setup_sokcet2, ring_tunnel::create_ring_tunnel_pair, - DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, + check_scheme_and_get_socket_addr, + common::{setup_sokcet2, setup_sokcet2_ext}, + ring_tunnel::create_ring_tunnel_pair, + DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl, }; const MAX_PACKET: usize = 65500; @@ -132,7 +134,7 @@ impl Debug for WgPeerData { impl WgPeerData { #[tracing::instrument] async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> { - let mut send_buf = [0u8; MAX_PACKET]; + let mut send_buf = vec![0u8; MAX_PACKET]; let encapsulate_result = { let mut peer = self.tunn.lock().await; @@ -180,7 +182,7 @@ impl WgPeerData { /// decapsulates them, and dispatches newly received IP packets. #[tracing::instrument] pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) { - let mut send_buf = [0u8; MAX_PACKET]; + let mut send_buf = vec![0u8; MAX_PACKET]; let data = &recv_buf[..]; let decapsulate_result = { let mut peer = self.tunn.lock().await; @@ -200,7 +202,7 @@ impl WgPeerData { }; let mut peer = self.tunn.lock().await; loop { - let mut send_buf = [0u8; MAX_PACKET]; + let mut send_buf = vec![0u8; MAX_PACKET]; match peer.decapsulate(None, &[], &mut send_buf) { TunnResult::WriteToNetwork(packet) => { match self.udp.send_to(packet, self.endpoint).await { @@ -288,10 +290,11 @@ impl WgPeerData { } TunnResult::Done => { // Sleep for a bit - tokio::time::sleep(Duration::from_millis(1)).await; + tokio::time::sleep(Duration::from_millis(250)).await; } other => { tracing::warn!("Unexpected WireGuard routine task state: {:?}", other); + tokio::time::sleep(Duration::from_millis(250)).await; } }; } @@ -299,7 +302,7 @@ impl WgPeerData { /// WireGuard Routine task. Handles Handshake, keep-alive, etc. pub async fn routine_task(self) { loop { - let mut send_buf = [0u8; MAX_PACKET]; + let mut send_buf = vec![0u8; MAX_PACKET]; let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) }; self.handle_routine_tun_result(tun_result).await; } @@ -462,7 +465,7 @@ impl WgTunnelListener { } }); - let mut buf = [0u8; 4096]; + let mut buf = vec![0u8; MAX_PACKET]; loop { let Ok((n, addr)) = socket.recv_from(&mut buf).await else { tracing::error!("Failed to receive from UDP socket"); @@ -508,7 +511,14 @@ impl TunnelListener for WgTunnelListener { socket2::Type::DGRAM, Some(socket2::Protocol::UDP), )?; - setup_sokcet2(&socket2_socket, &addr)?; + + let tunnel_url: TunnelUrl = self.addr.clone().into(); + if let Some(bind_dev) = tunnel_url.bind_dev() { + setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + } else { + setup_sokcet2(&socket2_socket, &addr)?; + } + self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); self.tasks.spawn(Self::handle_udp_incoming( self.get_udp_socket(), @@ -636,7 +646,7 @@ impl WgTunnelConnector { let init = Self::create_handshake_init(&mut my_tun); udp.send_to(&init, addr).await?; - let mut buf = [0u8; MAX_PACKET]; + let mut buf = vec![0u8; MAX_PACKET]; let (n, _) = udp.recv_from(&mut buf).await.unwrap(); let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]); udp.send_to(&keepalive, addr).await?; @@ -647,7 +657,7 @@ impl WgTunnelConnector { let data = wg_peer.data.as_ref().unwrap().clone(); wg_peer.tasks.spawn(async move { loop { - let mut buf = [0u8; MAX_PACKET]; + let mut buf = vec![0u8; MAX_PACKET]; let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); if recv_addr != addr { continue;