From 8aa57ebc22b57d10a3a5df9542969f8c5829fb7d Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Mon, 10 Jun 2024 10:27:24 +0800 Subject: [PATCH] support no tun mode (#141) --- Cargo.lock | 79 +++- easytier/Cargo.toml | 8 +- easytier/src/common/config.rs | 4 + easytier/src/common/global_ctx.rs | 7 + easytier/src/connector/mod.rs | 3 +- easytier/src/easytier-core.rs | 16 + easytier/src/gateway/icmp_proxy.rs | 102 ++++- easytier/src/gateway/mod.rs | 3 +- easytier/src/gateway/tcp_proxy.rs | 249 ++++++++++-- .../gateway/tokio_smoltcp/channel_device.rs | 75 ++++ easytier/src/gateway/tokio_smoltcp/device.rs | 122 ++++++ easytier/src/gateway/tokio_smoltcp/mod.rs | 220 ++++++++++ easytier/src/gateway/tokio_smoltcp/reactor.rs | 163 ++++++++ easytier/src/gateway/tokio_smoltcp/socket.rs | 377 ++++++++++++++++++ .../gateway/tokio_smoltcp/socket_allocator.rs | 145 +++++++ easytier/src/gateway/udp_proxy.rs | 38 +- easytier/src/instance/instance.rs | 27 +- easytier/src/peers/peer_manager.rs | 2 +- easytier/src/peers/peer_ospf_route.rs | 2 +- easytier/src/tests/three_node.rs | 248 +++++++----- easytier/src/tunnel/wireguard.rs | 2 +- 21 files changed, 1722 insertions(+), 170 deletions(-) create mode 100644 easytier/src/gateway/tokio_smoltcp/channel_device.rs create mode 100644 easytier/src/gateway/tokio_smoltcp/device.rs create mode 100644 easytier/src/gateway/tokio_smoltcp/mod.rs create mode 100644 easytier/src/gateway/tokio_smoltcp/reactor.rs create mode 100644 easytier/src/gateway/tokio_smoltcp/socket.rs create mode 100644 easytier/src/gateway/tokio_smoltcp/socket_allocator.rs diff --git a/Cargo.lock b/Cargo.lock index cc8ac69..d0bb12c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1084,6 +1084,38 @@ dependencies = [ "thiserror", ] +[[package]] +name = "defmt" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a99dd22262668b887121d4672af5a64b238f026099f1a2a1b322066c9ecfe9e0" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9f309eff1f79b3ebdf252954d90ae440599c26c2c553fe87a2d17195f2dcb" +dependencies = [ + "defmt-parser", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.65", +] + +[[package]] +name = "defmt-parser" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff4a5fefe330e8d7f31b16a318f9ce81000d8e35e69b93eae154d16d2278f70f" +dependencies = [ + "thiserror", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1244,6 +1276,7 @@ dependencies = [ "network-interface", "nix 0.27.1", "once_cell", + "parking_lot", "percent-encoding", "petgraph", "pin-project-lite", @@ -1259,6 +1292,7 @@ dependencies = [ "rustls", "serde", "serial_test", + "smoltcp", "socket2", "stun_codec", "tabled", @@ -1975,6 +2009,15 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1994,13 +2037,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" dependencies = [ "atomic-polyfill", - "hash32", + "hash32 0.2.1", "rustc_version", "serde", "spin 0.9.8", "stable_deref_trait", ] +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32 0.3.1", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.3.3" @@ -2606,6 +2659,12 @@ dependencies = [ "libc", ] +[[package]] +name = "managed" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" + [[package]] name = "markup5ever" version = "0.11.0" @@ -3550,7 +3609,7 @@ checksum = "a55c51ee6c0db07e68448e336cf8ea4131a620edefebf9893e759b2d793420f8" dependencies = [ "cobs", "embedded-io", - "heapless", + "heapless 0.7.17", "serde", ] @@ -4491,6 +4550,22 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smoltcp" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a1a996951e50b5971a2c8c0fa05a381480d70a933064245c4a223ddc87ccc97" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "cfg-if", + "defmt", + "heapless 0.8.0", + "libc", + "log", + "managed", +] + [[package]] name = "socket2" version = "0.5.7" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 362d623..e150278 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -155,6 +155,9 @@ indexmap = { version = "~1.9.3", optional = false, features = ["std"] } atomic-shim = "0.2.0" +smoltcp = { version = "0.11.0", optional = true } +parking_lot = { version = "0.12.0", optional = true } + [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.52", features = [ "Win32_Networking_WinSock", @@ -179,8 +182,8 @@ defguard_wireguard_rs = "0.4.2" [features] -default = ["wireguard", "mimalloc", "websocket"] -full = ["quic", "websocket", "wireguard", "mimalloc", "aes-gcm"] +default = ["wireguard", "mimalloc", "websocket", "smoltcp"] +full = ["quic", "websocket", "wireguard", "mimalloc", "aes-gcm", "smoltcp"] mips = ["aes-gcm", "mimalloc", "wireguard"] wireguard = ["dep:boringtun", "dep:ring"] quic = ["dep:quinn", "dep:rustls", "dep:rcgen"] @@ -193,3 +196,4 @@ websocket = [ "dep:rustls", "dep:rcgen", ] +smoltcp = ["dep:smoltcp", "dep:parking_lot"] diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index dfeeda6..fda96dc 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -160,6 +160,10 @@ pub struct Flags { pub latency_first: bool, #[derivative(Default(value = "false"))] pub enable_exit_node: bool, + #[derivative(Default(value = "false"))] + pub no_tun: bool, + #[derivative(Default(value = "false"))] + pub use_smoltcp: bool, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 9026251..9da3dcb 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -65,6 +65,7 @@ pub struct GlobalCtx { running_listeners: Mutex>, enable_exit_node: bool, + no_tun: bool, } impl std::fmt::Debug for GlobalCtx { @@ -93,6 +94,7 @@ impl GlobalCtx { let stun_info_collection = Arc::new(StunInfoCollector::new_with_default_servers()); let enable_exit_node = config_fs.get_flags().enable_exit_node; + let no_tun = config_fs.get_flags().no_tun; GlobalCtx { inst_name: config_fs.get_inst_name(), @@ -114,6 +116,7 @@ impl GlobalCtx { running_listeners: Mutex::new(Vec::new()), enable_exit_node, + no_tun, } } @@ -234,6 +237,10 @@ impl GlobalCtx { pub fn enable_exit_node(&self) -> bool { self.enable_exit_node } + + pub fn no_tun(&self) -> bool { + self.no_tun + } } #[cfg(test)] diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 085bd8d..44c872a 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -11,7 +11,7 @@ use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, tunnel::{ check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector, - udp::UdpTunnelConnector, FromUrl, IpVersion, TunnelConnector, + udp::UdpTunnelConnector, TunnelConnector, }, }; @@ -107,6 +107,7 @@ pub async fn create_connector_by_url( } #[cfg(feature = "websocket")] "ws" | "wss" => { + use crate::tunnel::{FromUrl, IpVersion}; let dst_addr = SocketAddr::from_url(url.clone(), IpVersion::Both)?; let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url); set_bind_addr_for_peer_connector( diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index e70662f..66fc4b0 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -198,6 +198,20 @@ and the vpn client is in network of 10.14.14.0/24" default_value = "false" )] enable_exit_node: bool, + + #[arg( + long, + help = "do not create TUN device, can use subnet proxy to access node", + default_value = "false" + )] + no_tun: bool, + + #[arg( + long, + help = "enable smoltcp stack for subnet proxy", + default_value = "true" + )] + use_smoltcp: bool, } impl Cli { @@ -414,6 +428,8 @@ impl From for TomlConfigLoader { f.mtu = mtu; } f.enable_exit_node = cli.enable_exit_node; + f.no_tun = cli.no_tun; + f.use_smoltcp = cli.use_smoltcp; cfg.set_flags(f); cfg.set_exit_nodes(cli.exit_nodes.clone()); diff --git a/easytier/src/gateway/icmp_proxy.rs b/easytier/src/gateway/icmp_proxy.rs index 3e0f4d4..cb75c75 100644 --- a/easytier/src/gateway/icmp_proxy.rs +++ b/easytier/src/gateway/icmp_proxy.rs @@ -7,7 +7,7 @@ use std::{ }; use pnet::packet::{ - icmp::{self, IcmpTypes}, + icmp::{self, echo_reply::MutableEchoReplyPacket, IcmpCode, IcmpTypes, MutableIcmpPacket}, ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, Packet, @@ -74,6 +74,7 @@ pub struct IcmpProxy { tasks: Mutex>, ip_resemmbler: Arc, + icmp_sender: Arc>>>, } fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit]) -> Result<(usize, IpAddr), Error> { @@ -181,12 +182,13 @@ impl IcmpProxy { tasks: Mutex::new(JoinSet::new()), ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))), + icmp_sender: Arc::new(std::sync::Mutex::new(None)), }; Ok(Arc::new(ret)) } - pub async fn start(self: &Arc) -> Result<(), Error> { + fn create_raw_socket(self: &Arc) -> Result { let _g = self.global_ctx.net_ns.guard(); let socket = socket2::Socket::new( socket2::Domain::IPV4, @@ -197,7 +199,22 @@ impl IcmpProxy { std::net::Ipv4Addr::UNSPECIFIED, 0, )))?; - self.socket.lock().unwrap().replace(socket); + Ok(socket) + } + + pub async fn start(self: &Arc) -> Result<(), Error> { + let socket = self.create_raw_socket(); + match socket { + Ok(socket) => { + self.socket.lock().unwrap().replace(socket); + } + Err(e) => { + tracing::warn!("create icmp socket failed: {:?}", e); + if !self.global_ctx.no_tun() { + return Err(e); + } + } + } self.start_icmp_proxy().await?; self.start_nat_table_cleaner().await?; @@ -219,12 +236,15 @@ impl IcmpProxy { } async fn start_icmp_proxy(self: &Arc) -> Result<(), Error> { - let socket = self.socket.lock().unwrap().as_ref().unwrap().try_clone()?; let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); - let nat_table = self.nat_table.clone(); - thread::spawn(|| { - socket_recv_loop(socket, nat_table, sender); - }); + self.icmp_sender.lock().unwrap().replace(sender.clone()); + if let Some(socket) = self.socket.lock().unwrap().as_ref() { + let socket = socket.try_clone()?; + let nat_table = self.nat_table.clone(); + thread::spawn(|| { + socket_recv_loop(socket, nat_table, sender); + }); + } let peer_manager = self.peer_manager.clone(); self.tasks.lock().await.spawn( @@ -268,8 +288,54 @@ impl IcmpProxy { Ok(()) } + async fn send_icmp_reply_to_peer( + &self, + src_ip: &Ipv4Addr, + dst_ip: &Ipv4Addr, + src_peer_id: PeerId, + dst_peer_id: PeerId, + icmp_packet: &icmp::echo_request::EchoRequestPacket<'_>, + ) { + let mut buf = vec![0u8; icmp_packet.packet().len() + 20]; + let mut reply_packet = MutableEchoReplyPacket::new(&mut buf[20..]).unwrap(); + reply_packet.set_icmp_type(IcmpTypes::EchoReply); + reply_packet.set_icmp_code(IcmpCode::new(0)); + reply_packet.set_identifier(icmp_packet.get_identifier()); + reply_packet.set_sequence_number(icmp_packet.get_sequence_number()); + reply_packet.set_payload(icmp_packet.payload()); + + let mut icmp_packet = MutableIcmpPacket::new(&mut buf[20..]).unwrap(); + icmp_packet.set_checksum(icmp::checksum(&icmp_packet.to_immutable())); + + let len = buf.len() - 20; + let _ = compose_ipv4_packet( + &mut buf[..], + src_ip, + dst_ip, + IpNextHeaderProtocols::Icmp, + len, + 1200, + rand::random(), + |buf| { + let mut packet = ZCPacket::new_with_payload(buf); + packet.fill_peer_manager_hdr(src_peer_id, dst_peer_id, PacketType::Data as u8); + let _ = self + .icmp_sender + .lock() + .unwrap() + .as_ref() + .unwrap() + .send(packet); + Ok(()) + }, + ); + } + async fn try_handle_peer_packet(&self, packet: &ZCPacket) -> Option<()> { - if self.cidr_set.is_empty() && !self.global_ctx.enable_exit_node() { + if self.cidr_set.is_empty() + && !self.global_ctx.enable_exit_node() + && !self.global_ctx.no_tun() + { return None; } @@ -288,7 +354,11 @@ impl IcmpProxy { return None; } - if !self.cidr_set.contains_v4(ipv4.get_destination()) && !is_exit_node { + if !self.cidr_set.contains_v4(ipv4.get_destination()) + && !is_exit_node + && !(self.global_ctx.no_tun() + && Some(ipv4.get_destination()) == self.global_ctx.get_ipv4()) + { return None; } @@ -311,6 +381,18 @@ impl IcmpProxy { return Some(()); } + if self.global_ctx.no_tun() && Some(ipv4.get_destination()) == self.global_ctx.get_ipv4() { + self.send_icmp_reply_to_peer( + &ipv4.get_destination(), + &ipv4.get_source(), + hdr.to_peer_id.get(), + hdr.from_peer_id.get(), + &icmp_packet, + ) + .await; + return Some(()); + } + let icmp_id = icmp_packet.get_identifier(); let icmp_seq = icmp_packet.get_sequence_number(); diff --git a/easytier/src/gateway/mod.rs b/easytier/src/gateway/mod.rs index c1007d8..433e0c7 100644 --- a/easytier/src/gateway/mod.rs +++ b/easytier/src/gateway/mod.rs @@ -6,8 +6,9 @@ use crate::common::global_ctx::ArcGlobalCtx; pub mod icmp_proxy; pub mod ip_reassembler; pub mod tcp_proxy; +#[cfg(feature = "smoltcp")] +pub mod tokio_smoltcp; pub mod udp_proxy; - #[derive(Debug)] struct CidrSet { global_ctx: ArcGlobalCtx, diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index 926518b..d32bc98 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -1,3 +1,4 @@ +use core::panic; use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use pnet::packet::ip::IpNextHeaderProtocols; @@ -6,19 +7,18 @@ use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket}; use pnet::packet::MutablePacket; use pnet::packet::Packet; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; -use std::sync::atomic::AtomicU16; +use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::io::copy_bidirectional; use tokio::net::{TcpListener, TcpSocket, TcpStream}; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; use tokio::task::JoinSet; use tracing::Instrument; use crate::common::error::Result; -use crate::common::global_ctx::GlobalCtx; +use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx}; use crate::common::join_joinset_background; -use crate::common::netns::NetNS; use crate::peers::peer_manager::PeerManager; use crate::peers::{NicPacketFilter, PeerPacketFilter}; @@ -26,6 +26,9 @@ use crate::tunnel::packet_def::{PacketType, ZCPacket}; use super::CidrSet; +#[cfg(feature = "smoltcp")] +use super::tokio_smoltcp::{self, channel_device, Net, NetConfig}; + #[derive(Debug, Clone, Copy, PartialEq)] enum NatDstEntryState { // receive syn packet but not start connecting to dst @@ -61,6 +64,64 @@ impl NatDstEntry { } } +enum ProxyTcpStream { + KernelTcpStream(TcpStream), + #[cfg(feature = "smoltcp")] + SmolTcpStream(tokio_smoltcp::TcpStream), +} + +impl ProxyTcpStream { + pub fn set_nodelay(&self, nodelay: bool) -> Result<()> { + match self { + Self::KernelTcpStream(stream) => stream.set_nodelay(nodelay).map_err(Into::into), + #[cfg(feature = "smoltcp")] + Self::SmolTcpStream(_stream) => { + tracing::warn!("smol tcp stream set_nodelay not implemented"); + Ok(()) + } + } + } + + pub async fn copy_bidirectional(&mut self, dst: &mut TcpStream) -> Result<()> { + match self { + Self::KernelTcpStream(stream) => { + copy_bidirectional(stream, dst).await?; + Ok(()) + } + #[cfg(feature = "smoltcp")] + Self::SmolTcpStream(stream) => { + copy_bidirectional(stream, dst).await?; + Ok(()) + } + } + } +} + +enum ProxyTcpListener { + KernelTcpListener(TcpListener), + #[cfg(feature = "smoltcp")] + SmolTcpListener(tokio_smoltcp::TcpListener), +} + +impl ProxyTcpListener { + pub async fn accept(&mut self) -> Result<(ProxyTcpStream, SocketAddr)> { + match self { + Self::KernelTcpListener(listener) => { + let (stream, addr) = listener.accept().await?; + Ok((ProxyTcpStream::KernelTcpStream(stream), addr)) + } + #[cfg(feature = "smoltcp")] + Self::SmolTcpListener(listener) => { + let Ok((stream, src)) = listener.accept().await else { + return Err(anyhow::anyhow!("smol tcp listener closed").into()); + }; + tracing::info!(?src, "smol tcp listener accepted"); + Ok((ProxyTcpStream::SmolTcpStream(stream), src)) + } + } + } +} + type ArcNatDstEntry = Arc; type SynSockMap = Arc>; @@ -81,6 +142,12 @@ pub struct TcpProxy { addr_conn_map: AddrConnSockMap, cidr_set: CidrSet, + + smoltcp_stack_sender: Option>, + smoltcp_stack_receiver: Arc>>>, + #[cfg(feature = "smoltcp")] + smoltcp_net: Arc>>, + enable_smoltcp: Arc, } #[async_trait::async_trait] @@ -157,6 +224,8 @@ impl NicPacketFilter for TcpProxy { impl TcpProxy { pub fn new(global_ctx: Arc, peer_manager: Arc) -> Arc { + let (smoltcp_stack_sender, smoltcp_stack_receiver) = mpsc::channel::(1000); + Arc::new(Self { global_ctx: global_ctx.clone(), peer_manager, @@ -169,6 +238,14 @@ impl TcpProxy { addr_conn_map: Arc::new(DashMap::new()), cidr_set: CidrSet::new(global_ctx), + + smoltcp_stack_sender: Some(smoltcp_stack_sender), + smoltcp_stack_receiver: Arc::new(Mutex::new(Some(smoltcp_stack_receiver))), + + #[cfg(feature = "smoltcp")] + smoltcp_net: Arc::new(Mutex::new(None)), + + enable_smoltcp: Arc::new(AtomicBool::new(true)), }) } @@ -224,34 +301,114 @@ impl TcpProxy { Ok(()) } + async fn get_proxy_listener(&self) -> Result { + #[cfg(feature = "smoltcp")] + if self.global_ctx.get_flags().use_smoltcp || self.global_ctx.no_tun() { + // use smoltcp network stack + self.local_port + .store(8899, std::sync::atomic::Ordering::Relaxed); + + let mut cap = smoltcp::phy::DeviceCapabilities::default(); + cap.max_transmission_unit = 1280; + cap.medium = smoltcp::phy::Medium::Ip; + let (dev, stack_sink, mut stack_stream) = channel_device::ChannelDevice::new(cap); + + let mut smoltcp_stack_receiver = + self.smoltcp_stack_receiver.lock().await.take().unwrap(); + self.tasks.lock().unwrap().spawn(async move { + while let Some(packet) = smoltcp_stack_receiver.recv().await { + tracing::trace!(?packet, "receive from peer send to smoltcp packet"); + if let Err(e) = stack_sink.send(Ok(packet.payload().to_vec())).await { + tracing::error!("send to smoltcp stack failed: {:?}", e); + } + } + tracing::error!("smoltcp stack sink exited"); + panic!("smoltcp stack sink exited"); + }); + + let peer_mgr = self.peer_manager.clone(); + self.tasks.lock().unwrap().spawn(async move { + while let Some(data) = stack_stream.recv().await { + tracing::trace!( + ?data, + "receive from smoltcp stack and send to peer mgr packet" + ); + let Some(ipv4) = Ipv4Packet::new(&data) else { + tracing::error!(?data, "smoltcp stack stream get non ipv4 packet"); + continue; + }; + + let dst = ipv4.get_destination(); + let packet = ZCPacket::new_with_payload(&data); + if let Err(e) = peer_mgr.send_msg_ipv4(packet, dst).await { + tracing::error!("send to peer failed in smoltcp sender: {:?}", e); + } + } + tracing::error!("smoltcp stack stream exited"); + panic!("smoltcp stack stream exited"); + }); + + let interface_config = smoltcp::iface::Config::new(smoltcp::wire::HardwareAddress::Ip); + let net = Net::new( + dev, + NetConfig::new( + interface_config, + format!("{}/24", self.global_ctx.get_ipv4().unwrap()) + .parse() + .unwrap(), + vec![], + ), + ); + net.set_any_ip(true); + let tcp = net.tcp_bind("0.0.0.0:8899".parse().unwrap()).await?; + self.smoltcp_net.lock().await.replace(net); + + self.enable_smoltcp + .store(true, std::sync::atomic::Ordering::Relaxed); + + return Ok(ProxyTcpListener::SmolTcpListener(tcp)); + } + + { + // use kernel network stack + let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); + let net_ns = self.global_ctx.net_ns.clone(); + let tcp_listener = net_ns + .run_async(|| async { TcpListener::bind(&listen_addr).await }) + .await?; + self.local_port.store( + tcp_listener.local_addr()?.port(), + std::sync::atomic::Ordering::Relaxed, + ); + + self.enable_smoltcp + .store(false, std::sync::atomic::Ordering::Relaxed); + + return Ok(ProxyTcpListener::KernelTcpListener(tcp_listener)); + } + } + async fn run_listener(&self) -> Result<()> { // bind on both v4 & v6 - let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0); - - let net_ns = self.global_ctx.net_ns.clone(); - let tcp_listener = net_ns - .run_async(|| async { TcpListener::bind(&listen_addr).await }) - .await?; - - self.local_port.store( - tcp_listener.local_addr()?.port(), - std::sync::atomic::Ordering::Relaxed, - ); + let mut tcp_listener = self.get_proxy_listener().await?; + let global_ctx = self.global_ctx.clone(); let tasks = self.tasks.clone(); let syn_map = self.syn_map.clone(); let conn_map = self.conn_map.clone(); let addr_conn_map = self.addr_conn_map.clone(); let accept_task = async move { - tracing::info!(listener = ?tcp_listener, "tcp connection start accepting"); - let conn_map = conn_map.clone(); while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await { let Some(entry) = syn_map.get(&socket_addr) else { tracing::error!("tcp connection from unknown source: {:?}", socket_addr); continue; }; - tracing::info!(?socket_addr, "tcp connection accepted for proxy"); + tracing::info!( + ?socket_addr, + "tcp connection accepted for proxy, nat dst: {:?}", + entry.dst + ); assert_eq!(entry.state.load(), NatDstEntryState::SynReceived); let entry_clone = entry.clone(); @@ -265,7 +422,7 @@ impl TcpProxy { assert!(old_nat_val.is_none()); tasks.lock().unwrap().spawn(Self::connect_to_nat_dst( - net_ns.clone(), + global_ctx.clone(), tcp_stream, conn_map.clone(), addr_conn_map.clone(), @@ -293,8 +450,8 @@ impl TcpProxy { } async fn connect_to_nat_dst( - net_ns: NetNS, - src_tcp_stream: TcpStream, + global_ctx: ArcGlobalCtx, + src_tcp_stream: ProxyTcpStream, conn_map: ConnSockMap, addr_conn_map: AddrConnSockMap, nat_entry: ArcNatDstEntry, @@ -303,14 +460,24 @@ impl TcpProxy { tracing::warn!("set_nodelay failed, ignore it: {:?}", e); } - let _guard = net_ns.guard(); + let _guard = global_ctx.net_ns.guard(); let socket = TcpSocket::new_v4().unwrap(); if let Err(e) = socket.set_nodelay(true) { tracing::warn!("set_nodelay failed, ignore it: {:?}", e); } + + let nat_dst = if Some(nat_entry.dst.ip()) == global_ctx.get_ipv4().map(|ip| IpAddr::V4(ip)) + { + format!("127.0.0.1:{}", nat_entry.dst.port()) + .parse() + .unwrap() + } else { + nat_entry.dst + }; + let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout( Duration::from_secs(10), - TcpSocket::new_v4().unwrap().connect(nat_entry.dst), + TcpSocket::new_v4().unwrap().connect(nat_dst), ) .await else { @@ -321,6 +488,8 @@ impl TcpProxy { }; drop(_guard); + tracing::info!(?nat_entry, ?nat_dst, "tcp connection to dst established"); + assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst); nat_entry.state.store(NatDstEntryState::Connected); @@ -335,7 +504,7 @@ impl TcpProxy { } async fn handle_nat_connection( - mut src_tcp_stream: TcpStream, + mut src_tcp_stream: ProxyTcpStream, mut dst_tcp_stream: TcpStream, conn_map: ConnSockMap, addr_conn_map: AddrConnSockMap, @@ -343,8 +512,8 @@ impl TcpProxy { ) { let nat_entry_clone = nat_entry.clone(); nat_entry.tasks.lock().await.spawn(async move { - let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await; - tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed"); + let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await; + tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed"); nat_entry_clone.state.store(NatDstEntryState::Closed); Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone); @@ -356,7 +525,10 @@ impl TcpProxy { } async fn try_handle_peer_packet(&self, packet: &mut ZCPacket) -> Option<()> { - if self.cidr_set.is_empty() && !self.global_ctx.enable_exit_node() { + if self.cidr_set.is_empty() + && !self.global_ctx.enable_exit_node() + && !self.global_ctx.no_tun() + { return None; } @@ -375,11 +547,15 @@ impl TcpProxy { return None; } - if !self.cidr_set.contains_v4(ipv4.get_destination()) && !is_exit_node { + if !self.cidr_set.contains_v4(ipv4.get_destination()) + && !is_exit_node + && !(self.global_ctx.no_tun() + && Some(ipv4.get_destination()) == self.global_ctx.get_ipv4()) + { return None; } - tracing::info!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received"); + tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received"); let ip_packet = Ipv4Packet::new(payload_bytes).unwrap(); let tcp_packet = TcpPacket::new(ip_packet.payload()).unwrap(); @@ -397,7 +573,7 @@ impl TcpProxy { let old_val = self .syn_map .insert(src, Arc::new(NatDstEntry::new(src, dst))); - tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received"); + tracing::info!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received"); } let mut ip_packet = MutableIpv4Packet::new(payload_bytes).unwrap(); @@ -411,7 +587,18 @@ impl TcpProxy { drop(tcp_packet); Self::update_ip_packet_checksum(&mut ip_packet); - tracing::info!(?source, ?ipv4_addr, ?packet, "tcp packet after modified"); + tracing::trace!(?source, ?ipv4_addr, ?packet, "tcp packet after modified"); + + if self + .enable_smoltcp + .load(std::sync::atomic::Ordering::Relaxed) + { + let smoltcp_stack_sender = self.smoltcp_stack_sender.as_ref().unwrap(); + if let Err(e) = smoltcp_stack_sender.try_send(packet.clone()) { + tracing::error!("send to smoltcp stack failed: {:?}", e); + } + return None; + } Some(()) } diff --git a/easytier/src/gateway/tokio_smoltcp/channel_device.rs b/easytier/src/gateway/tokio_smoltcp/channel_device.rs new file mode 100644 index 0000000..6ee90d4 --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/channel_device.rs @@ -0,0 +1,75 @@ +use futures::{Sink, Stream}; +use smoltcp::phy::DeviceCapabilities; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio_util::sync::{PollSendError, PollSender}; + +use super::device::AsyncDevice; + +/// A device that send and receive packets using a channel. +pub struct ChannelDevice { + recv: Receiver>>, + send: PollSender>, + caps: DeviceCapabilities, +} + +impl ChannelDevice { + /// Make a new `ChannelDevice` with the given `recv` and `send` channels. + /// + /// The `caps` is used to determine the device capabilities. `DeviceCapabilities::max_transmission_unit` must be set. + pub fn new(caps: DeviceCapabilities) -> (Self, Sender>>, Receiver>) { + let (tx1, rx1) = channel(1000); + let (tx2, rx2) = channel(1000); + ( + ChannelDevice { + send: PollSender::new(tx1), + recv: rx2, + caps, + }, + tx2, + rx1, + ) + } +} + +impl Stream for ChannelDevice { + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.recv.poll_recv(cx) + } +} + +fn map_err(e: PollSendError>) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} + +impl Sink> for ChannelDevice { + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send.poll_reserve(cx).map_err(map_err) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + self.send.send_item(item).map_err(map_err) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send.poll_reserve(cx).map_err(map_err) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl AsyncDevice for ChannelDevice { + fn capabilities(&self) -> &DeviceCapabilities { + &self.caps + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/device.rs b/easytier/src/gateway/tokio_smoltcp/device.rs new file mode 100644 index 0000000..c424984 --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/device.rs @@ -0,0 +1,122 @@ +use futures::{Sink, Stream}; +pub use smoltcp::phy::DeviceCapabilities; +use smoltcp::{ + phy::{Device, RxToken, TxToken}, + time::Instant, +}; +use std::{collections::VecDeque, io}; + +/// Default value of `max_burst_size`. +pub const DEFAULT_MAX_BURST_SIZE: usize = 100; + +/// A packet used in `AsyncDevice`. +pub type Packet = Vec; + +/// A device that send and receive packets asynchronously. +pub trait AsyncDevice: + Stream> + Sink + Send + Unpin +{ + /// Returns the device capabilities. + fn capabilities(&self) -> &DeviceCapabilities; +} + +impl AsyncDevice for Box +where + T: AsyncDevice, +{ + fn capabilities(&self) -> &DeviceCapabilities { + (**self).capabilities() + } +} + +/// A device that send and receive packets synchronously. +pub struct BufferDevice { + caps: DeviceCapabilities, + max_burst_size: usize, + recv_queue: VecDeque, + send_queue: VecDeque, +} + +/// RxToken for `BufferDevice`. +pub struct BufferRxToken(Packet); + +impl RxToken for BufferRxToken { + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let p = &mut self.0; + let result = f(p); + result + } +} + +/// TxToken for `BufferDevice`. +pub struct BufferTxToken<'a>(&'a mut BufferDevice); + +impl<'d> TxToken for BufferTxToken<'d> { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = vec![0u8; len]; + let result = f(&mut buffer); + + self.0.send_queue.push_back(buffer); + + result + } +} + +impl Device for BufferDevice { + type RxToken<'a> = BufferRxToken + where Self:'a; + type TxToken<'a> = BufferTxToken<'a> + where Self:'a; + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + match self.recv_queue.pop_front() { + Some(p) => Some((BufferRxToken(p), BufferTxToken(self))), + None => None, + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option> { + if self.send_queue.len() < self.max_burst_size { + Some(BufferTxToken(self)) + } else { + None + } + } + + fn capabilities(&self) -> DeviceCapabilities { + self.caps.clone() + } +} + +impl BufferDevice { + pub(crate) fn new(caps: DeviceCapabilities) -> BufferDevice { + let max_burst_size = caps.max_burst_size.unwrap_or(DEFAULT_MAX_BURST_SIZE); + BufferDevice { + caps, + max_burst_size, + recv_queue: VecDeque::with_capacity(max_burst_size), + send_queue: VecDeque::with_capacity(max_burst_size), + } + } + pub(crate) fn take_send_queue(&mut self) -> VecDeque { + std::mem::replace( + &mut self.send_queue, + VecDeque::with_capacity(self.max_burst_size), + ) + } + pub(crate) fn push_recv_queue(&mut self, p: impl Iterator) { + self.recv_queue.extend(p.take(self.avaliable_recv_queue())); + } + pub(crate) fn avaliable_recv_queue(&self) -> usize { + self.max_burst_size - self.recv_queue.len() + } + pub(crate) fn need_wait(&self) -> bool { + self.recv_queue.is_empty() + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/mod.rs b/easytier/src/gateway/tokio_smoltcp/mod.rs new file mode 100644 index 0000000..2b38bb4 --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/mod.rs @@ -0,0 +1,220 @@ +// most code is copied from https://github.com/spacemeowx2/tokio-smoltcp + +//! An asynchronous wrapper for smoltcp. + +use std::{ + io, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, +}; + +use device::BufferDevice; +use futures::Future; +use reactor::Reactor; +pub use smoltcp; +use smoltcp::{ + iface::{Config, Interface, Routes}, + time::{Duration, Instant}, + wire::{HardwareAddress, IpAddress, IpCidr, IpProtocol, IpVersion}, +}; +pub use socket::{RawSocket, TcpListener, TcpStream, UdpSocket}; +pub use socket_allocator::BufferSize; +use tokio::sync::Notify; + +/// The async devices. +pub mod channel_device; +pub mod device; +mod reactor; +mod socket; +mod socket_allocator; + +/// Can be used to create a forever timestamp in neighbor. +// The 60_000 is the same as NeighborCache::ENTRY_LIFETIME. +pub const FOREVER: Instant = + Instant::from_micros_const(i64::max_value() - Duration::from_millis(60_000).micros() as i64); + +pub struct Neighbor { + pub protocol_addr: IpAddress, + pub hardware_addr: HardwareAddress, + pub timestamp: Instant, +} + +/// A config for a `Net`. +/// +/// This is used to configure the `Net`. +#[non_exhaustive] +pub struct NetConfig { + pub interface_config: Config, + pub ip_addr: IpCidr, + pub gateway: Vec, + pub buffer_size: BufferSize, +} + +impl NetConfig { + pub fn new(interface_config: Config, ip_addr: IpCidr, gateway: Vec) -> Self { + Self { + interface_config, + ip_addr, + gateway, + buffer_size: Default::default(), + } + } +} + +/// `Net` is the main interface to the network stack. +/// Socket creation and configuration is done through the `Net` interface. +/// +/// When `Net` is dropped, all sockets are closed and the network stack is stopped. +pub struct Net { + reactor: Arc, + ip_addr: IpCidr, + from_port: AtomicU16, + stopper: Arc, +} + +impl std::fmt::Debug for Net { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Net") + .field("ip_addr", &self.ip_addr) + .field("from_port", &self.from_port) + .finish() + } +} + +impl Net { + /// Creates a new `Net` instance. It panics if the medium is not supported. + pub fn new(device: D, config: NetConfig) -> Net { + let (net, fut) = Self::new2(device, config); + tokio::spawn(fut); + net + } + + fn new2( + device: D, + config: NetConfig, + ) -> (Net, impl Future> + Send) { + let mut buffer_device = BufferDevice::new(device.capabilities().clone()); + let mut iface = Interface::new(config.interface_config, &mut buffer_device, Instant::now()); + let ip_addr = config.ip_addr; + iface.update_ip_addrs(|ip_addrs| { + ip_addrs.push(ip_addr).unwrap(); + }); + for gateway in config.gateway { + match gateway { + IpAddress::Ipv4(v4) => { + iface.routes_mut().add_default_ipv4_route(v4).unwrap(); + } + IpAddress::Ipv6(v6) => { + iface.routes_mut().add_default_ipv6_route(v6).unwrap(); + } + #[allow(unreachable_patterns)] + _ => panic!("Unsupported address"), + }; + } + + let stopper = Arc::new(Notify::new()); + let (reactor, fut) = Reactor::new( + device, + iface, + buffer_device, + config.buffer_size, + stopper.clone(), + ); + + ( + Net { + reactor: Arc::new(reactor), + ip_addr: config.ip_addr, + from_port: AtomicU16::new(10001), + stopper, + }, + fut, + ) + } + fn get_port(&self) -> u16 { + self.from_port + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| { + Some(if x > 60000 { 10000 } else { x + 1 }) + }) + .unwrap() + } + /// Creates a new TcpListener, which will be bound to the specified address. + pub async fn tcp_bind(&self, addr: SocketAddr) -> io::Result { + let addr = self.set_address(addr); + TcpListener::new(self.reactor.clone(), addr.into()).await + } + /// Opens a TCP connection to a remote host. + pub async fn tcp_connect(&self, addr: SocketAddr) -> io::Result { + TcpStream::connect( + self.reactor.clone(), + (self.ip_addr.address(), self.get_port()).into(), + addr.into(), + ) + .await + } + /// This function will create a new UDP socket and attempt to bind it to the `addr` provided. + pub async fn udp_bind(&self, addr: SocketAddr) -> io::Result { + let addr = self.set_address(addr); + UdpSocket::new(self.reactor.clone(), addr.into()).await + } + /// Creates a new raw socket. + pub async fn raw_socket( + &self, + ip_version: IpVersion, + ip_protocol: IpProtocol, + ) -> io::Result { + RawSocket::new(self.reactor.clone(), ip_version, ip_protocol).await + } + fn set_address(&self, mut addr: SocketAddr) -> SocketAddr { + if addr.ip().is_unspecified() { + addr.set_ip(match self.ip_addr.address() { + IpAddress::Ipv4(ip) => Ipv4Addr::from(ip).into(), + IpAddress::Ipv6(ip) => Ipv6Addr::from(ip).into(), + #[allow(unreachable_patterns)] + _ => panic!("address must not be unspecified"), + }); + } + if addr.port() == 0 { + addr.set_port(self.get_port()); + } + addr + } + + /// Enable or disable the AnyIP capability. + pub fn set_any_ip(&self, any_ip: bool) { + let iface = self.reactor.iface().clone(); + let mut iface: parking_lot::lock_api::MutexGuard<'_, parking_lot::RawMutex, Interface> = + iface.lock(); + iface.set_any_ip(any_ip); + } + + /// Get whether AnyIP is enabled. + pub fn any_ip(&self) -> bool { + let iface = self.reactor.iface().clone(); + let iface = iface.lock(); + iface.any_ip() + } + + pub fn routes(&self, f: F) { + let iface = self.reactor.iface().clone(); + let iface = iface.lock(); + let routes = iface.routes(); + f(routes) + } + + pub fn routes_mut(&self, f: F) { + let iface = self.reactor.iface().clone(); + let mut iface = iface.lock(); + let routes = iface.routes_mut(); + f(routes) + } +} + +impl Drop for Net { + fn drop(&mut self) { + self.stopper.notify_waiters() + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/reactor.rs b/easytier/src/gateway/tokio_smoltcp/reactor.rs new file mode 100644 index 0000000..65a4aeb --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/reactor.rs @@ -0,0 +1,163 @@ +use super::{ + device::{BufferDevice, Packet}, + socket_allocator::{BufferSize, SocketAlloctor}, +}; +use futures::{stream::iter, FutureExt, SinkExt, StreamExt}; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; +use smoltcp::{ + iface::{Context, Interface, SocketHandle}, + socket::{AnySocket, Socket}, + time::{Duration, Instant}, +}; +use std::{collections::VecDeque, future::Future, io, sync::Arc}; +use tokio::{pin, select, sync::Notify, time::sleep}; + +pub(crate) type BufferInterface = Arc>; +const MAX_BURST_SIZE: usize = 100; + +pub(crate) struct Reactor { + notify: Arc, + iface: BufferInterface, + socket_allocator: SocketAlloctor, +} + +async fn receive( + async_iface: &mut impl super::device::AsyncDevice, + recv_buf: &mut VecDeque, +) -> io::Result<()> { + if let Some(packet) = async_iface.next().await { + recv_buf.push_back(packet?); + } + Ok(()) +} + +async fn run( + mut async_iface: impl super::device::AsyncDevice, + iface: BufferInterface, + mut device: BufferDevice, + socket_allocator: SocketAlloctor, + notify: Arc, + stopper: Arc, +) -> io::Result<()> { + let default_timeout = Duration::from_secs(60); + let timer = sleep(default_timeout.into()); + let max_burst_size = async_iface + .capabilities() + .max_burst_size + .unwrap_or(MAX_BURST_SIZE); + let mut recv_buf = VecDeque::with_capacity(max_burst_size); + pin!(timer); + + loop { + let packets = device.take_send_queue(); + + async_iface + .send_all(&mut iter(packets).map(|p| Ok(p))) + .await?; + + if recv_buf.is_empty() && device.need_wait() { + let start = Instant::now(); + let deadline = { + iface + .lock() + .poll_delay(start, &socket_allocator.sockets().lock()) + .unwrap_or(default_timeout) + }; + + timer + .as_mut() + .reset(tokio::time::Instant::now() + deadline.into()); + select! { + _ = &mut timer => {}, + _ = receive(&mut async_iface,&mut recv_buf) => {} + _ = notify.notified() => {} + _ = stopper.notified() => break, + }; + + while let (true, Some(Ok(p))) = ( + recv_buf.len() < max_burst_size, + async_iface.next().now_or_never().flatten(), + ) { + recv_buf.push_back(p); + } + } + + let mut iface = iface.lock(); + + device.push_recv_queue(recv_buf.drain(..device.avaliable_recv_queue().min(recv_buf.len()))); + + iface.poll( + Instant::now(), + &mut device, + &mut socket_allocator.sockets().lock(), + ); + } + + Ok(()) +} + +impl Reactor { + pub fn new( + async_device: impl super::device::AsyncDevice, + iface: Interface, + device: BufferDevice, + buffer_size: BufferSize, + stopper: Arc, + ) -> (Self, impl Future> + Send) { + let iface = Arc::new(Mutex::new(iface)); + let notify = Arc::new(Notify::new()); + let socket_allocator = SocketAlloctor::new(buffer_size); + let fut = run( + async_device, + iface.clone(), + device, + socket_allocator.clone(), + notify.clone(), + stopper, + ); + + ( + Reactor { + notify, + iface: iface.clone(), + socket_allocator, + }, + fut, + ) + } + pub fn get_socket>( + &self, + handle: SocketHandle, + ) -> MappedMutexGuard<'_, T> { + MutexGuard::map( + self.socket_allocator.sockets().lock(), + |sockets: &mut smoltcp::iface::SocketSet<'_>| sockets.get_mut::(handle), + ) + } + pub fn context(&self) -> MappedMutexGuard<'_, Context> { + MutexGuard::map(self.iface.lock(), |iface| iface.context()) + } + pub fn socket_allocator(&self) -> &SocketAlloctor { + &self.socket_allocator + } + pub fn notify(&self) { + self.notify.notify_waiters(); + } + pub fn iface(&self) -> &BufferInterface { + &self.iface + } +} + +impl Drop for Reactor { + fn drop(&mut self) { + for (_, socket) in self.socket_allocator.sockets().lock().iter_mut() { + match socket { + Socket::Tcp(tcp) => tcp.close(), + Socket::Raw(_) => {} + Socket::Udp(udp) => udp.close(), + #[allow(unreachable_patterns)] + _ => {} + } + } + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/socket.rs b/easytier/src/gateway/tokio_smoltcp/socket.rs new file mode 100644 index 0000000..0d91d14 --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/socket.rs @@ -0,0 +1,377 @@ +use super::{reactor::Reactor, socket_allocator::SocketHandle}; +use futures::future::{self, poll_fn}; +use futures::{ready, Stream}; +pub use smoltcp::socket::{raw, tcp, udp}; +use smoltcp::wire::{IpAddress, IpEndpoint, IpProtocol, IpVersion}; +use std::mem::replace; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::{ + io, + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A TCP socket server, listening for connections. +/// +/// You can accept a new connection by using the accept method. +pub struct TcpListener { + handle: SocketHandle, + reactor: Arc, + local_addr: SocketAddr, +} + +fn map_err(e: E) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} + +impl TcpListener { + pub(super) async fn new( + reactor: Arc, + local_endpoint: IpEndpoint, + ) -> io::Result { + let handle = reactor.socket_allocator().new_tcp_socket(); + { + let mut socket = reactor.get_socket::(*handle); + socket.listen(local_endpoint).map_err(map_err)?; + } + + let local_addr = ep2sa(&local_endpoint); + Ok(TcpListener { + handle, + reactor, + local_addr, + }) + } + pub fn poll_accept( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + + if socket.state() == tcp::State::Established { + drop(socket); + return Poll::Ready(Ok(TcpStream::accept(self)?)); + } + socket.register_send_waker(cx.waker()); + Poll::Pending + } + pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { + poll_fn(|cx| self.poll_accept(cx)).await + } + pub fn incoming(self) -> Incoming { + Incoming(self) + } + pub fn local_addr(&self) -> io::Result { + Ok(self.local_addr) + } +} + +pub struct Incoming(TcpListener); + +impl Stream for Incoming { + type Item = io::Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let (tcp, _) = ready!(self.0.poll_accept(cx))?; + Poll::Ready(Some(Ok(tcp))) + } +} + +fn ep2sa(ep: &IpEndpoint) -> SocketAddr { + match ep.addr { + IpAddress::Ipv4(v4) => SocketAddr::new(IpAddr::V4(Ipv4Addr::from(v4)), ep.port), + IpAddress::Ipv6(v6) => SocketAddr::new(IpAddr::V6(Ipv6Addr::from(v6)), ep.port), + #[allow(unreachable_patterns)] + _ => unreachable!(), + } +} + +/// A TCP stream between a local and a remote socket. +pub struct TcpStream { + handle: SocketHandle, + reactor: Arc, + local_addr: SocketAddr, + peer_addr: SocketAddr, +} + +impl TcpStream { + pub(super) async fn connect( + reactor: Arc, + local_endpoint: IpEndpoint, + remote_endpoint: IpEndpoint, + ) -> io::Result { + let handle = reactor.socket_allocator().new_tcp_socket(); + + reactor + .get_socket::(*handle) + .connect(&mut reactor.context(), remote_endpoint, local_endpoint) + .map_err(map_err)?; + + let local_addr = ep2sa(&local_endpoint); + let peer_addr = ep2sa(&remote_endpoint); + let tcp = TcpStream { + handle, + reactor, + local_addr, + peer_addr, + }; + + tcp.reactor.notify(); + future::poll_fn(|cx| tcp.poll_connected(cx)).await?; + + Ok(tcp) + } + + fn accept(listener: &mut TcpListener) -> io::Result<(TcpStream, SocketAddr)> { + let reactor = listener.reactor.clone(); + let new_handle = reactor.socket_allocator().new_tcp_socket(); + { + let mut new_socket = reactor.get_socket::(*new_handle); + new_socket.listen(listener.local_addr).map_err(map_err)?; + } + let (peer_addr, local_addr) = { + let socket = reactor.get_socket::(*listener.handle); + ( + // should be Some, because the state is Established + ep2sa(&socket.remote_endpoint().unwrap()), + ep2sa(&socket.local_endpoint().unwrap()), + ) + }; + + Ok(( + TcpStream { + handle: replace(&mut listener.handle, new_handle), + reactor: reactor.clone(), + local_addr, + peer_addr, + }, + peer_addr, + )) + } + + pub fn local_addr(&self) -> io::Result { + Ok(self.local_addr) + } + pub fn peer_addr(&self) -> io::Result { + Ok(self.peer_addr) + } + pub fn poll_connected(&self, cx: &mut Context<'_>) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + if socket.state() == tcp::State::Established { + return Poll::Ready(Ok(())); + } + socket.register_send_waker(cx.waker()); + Poll::Pending + } +} + +impl AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + if !socket.may_recv() { + return Poll::Ready(Ok(())); + } + if socket.can_recv() { + let read = socket + .recv_slice(buf.initialize_unfilled()) + .map_err(map_err)?; + self.reactor.notify(); + buf.advance(read); + return Poll::Ready(Ok(())); + } + socket.register_recv_waker(cx.waker()); + Poll::Pending + } +} + +impl AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + if !socket.may_send() { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())); + } + if socket.can_send() { + let r = socket.send_slice(buf).map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok(r)); + } + socket.register_send_waker(cx.waker()); + Poll::Pending + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + if socket.send_queue() == 0 { + return Poll::Ready(Ok(())); + } + socket.register_send_waker(cx.waker()); + Poll::Pending + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + + if socket.is_open() { + socket.close(); + self.reactor.notify(); + } + if socket.state() == tcp::State::Closed { + return Poll::Ready(Ok(())); + } + + socket.register_send_waker(cx.waker()); + Poll::Pending + } +} + +/// A UDP socket. +pub struct UdpSocket { + handle: SocketHandle, + reactor: Arc, + local_addr: SocketAddr, +} + +impl UdpSocket { + pub(super) async fn new( + reactor: Arc, + local_endpoint: IpEndpoint, + ) -> io::Result { + let handle = reactor.socket_allocator().new_udp_socket(); + { + let mut socket = reactor.get_socket::(*handle); + socket.bind(local_endpoint).map_err(map_err)?; + } + + let local_addr = ep2sa(&local_endpoint); + + Ok(UdpSocket { + handle, + reactor, + local_addr, + }) + } + /// Note that on multiple calls to a poll_* method in the send direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup. + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: SocketAddr, + ) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + let target_ip: IpEndpoint = target.into(); + + match socket.send_slice(buf, target_ip) { + // the buffer is full + Err(udp::SendError::BufferFull) => {} + r => { + r.map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok(buf.len())); + } + } + + socket.register_send_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_send_to` + pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + poll_fn(|cx| self.poll_send_to(cx, buf, target)).await + } + /// Note that on multiple calls to a poll_* method in the recv direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + + match socket.recv_slice(buf) { + // the buffer is empty + Err(udp::RecvError::Exhausted) => {} + r => { + let (size, metadata) = r.map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok((size, ep2sa(&metadata.endpoint)))); + } + } + + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_recv_from` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.poll_recv_from(cx, buf)).await + } + pub fn local_addr(&self) -> io::Result { + Ok(self.local_addr) + } +} + +/// A raw socket. +pub struct RawSocket { + handle: SocketHandle, + reactor: Arc, +} + +impl RawSocket { + pub(super) async fn new( + reactor: Arc, + ip_version: IpVersion, + ip_protocol: IpProtocol, + ) -> io::Result { + let handle = reactor + .socket_allocator() + .new_raw_socket(ip_version, ip_protocol); + + Ok(RawSocket { handle, reactor }) + } + /// Note that on multiple calls to a poll_* method in the send direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup. + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + + match socket.send_slice(buf) { + // the buffer is full + Err(raw::SendError::BufferFull) => {} + r => { + r.map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok(buf.len())); + } + } + + socket.register_send_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_send` + pub async fn send(&self, buf: &[u8]) -> io::Result { + poll_fn(|cx| self.poll_send(cx, buf)).await + } + /// Note that on multiple calls to a poll_* method in the recv direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup. + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + let mut socket = self.reactor.get_socket::(*self.handle); + + match socket.recv_slice(buf) { + // the buffer is empty + Err(raw::RecvError::Exhausted) => {} + r => { + let size = r.map_err(map_err)?; + return Poll::Ready(Ok(size)); + } + } + + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_recv` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| self.poll_recv(cx, buf)).await + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs b/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs new file mode 100644 index 0000000..9b1a055 --- /dev/null +++ b/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs @@ -0,0 +1,145 @@ +use parking_lot::Mutex; +use smoltcp::{ + iface::{SocketHandle as InnerSocketHandle, SocketSet}, + socket::{raw, tcp, udp}, + wire::{IpProtocol, IpVersion}, +}; +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; + +/// `BufferSize` is used to configure the size of the socket buffer. +#[derive(Debug, Clone, Copy)] +pub struct BufferSize { + pub tcp_rx_size: usize, + pub tcp_tx_size: usize, + pub udp_rx_size: usize, + pub udp_tx_size: usize, + pub udp_rx_meta_size: usize, + pub udp_tx_meta_size: usize, + pub raw_rx_size: usize, + pub raw_tx_size: usize, + pub raw_rx_meta_size: usize, + pub raw_tx_meta_size: usize, +} + +impl Default for BufferSize { + fn default() -> Self { + BufferSize { + tcp_rx_size: 8192, + tcp_tx_size: 8192, + udp_rx_size: 8192, + udp_tx_size: 8192, + udp_rx_meta_size: 32, + udp_tx_meta_size: 32, + raw_rx_size: 8192, + raw_tx_size: 8192, + raw_rx_meta_size: 32, + raw_tx_meta_size: 32, + } + } +} + +type SharedSocketSet = Arc>>; + +#[derive(Clone)] +pub struct SocketAlloctor { + sockets: SharedSocketSet, + buffer_size: BufferSize, +} + +impl SocketAlloctor { + pub(crate) fn new(buffer_size: BufferSize) -> SocketAlloctor { + let sockets = Arc::new(Mutex::new(SocketSet::new(Vec::new()))); + SocketAlloctor { + sockets, + buffer_size, + } + } + pub(crate) fn sockets(&self) -> &SharedSocketSet { + &self.sockets + } + pub fn new_tcp_socket(&self) -> SocketHandle { + let mut set = self.sockets.lock(); + let handle = set.add(self.alloc_tcp_socket()); + SocketHandle::new(handle, self.sockets.clone()) + } + pub fn new_udp_socket(&self) -> SocketHandle { + let mut set = self.sockets.lock(); + let handle = set.add(self.alloc_udp_socket()); + SocketHandle::new(handle, self.sockets.clone()) + } + pub fn new_raw_socket(&self, ip_version: IpVersion, ip_protocol: IpProtocol) -> SocketHandle { + let mut set = self.sockets.lock(); + let handle = set.add(self.alloc_raw_socket(ip_version, ip_protocol)); + SocketHandle::new(handle, self.sockets.clone()) + } + fn alloc_tcp_socket(&self) -> tcp::Socket<'static> { + let rx_buffer = tcp::SocketBuffer::new(vec![0; self.buffer_size.tcp_rx_size]); + let tx_buffer = tcp::SocketBuffer::new(vec![0; self.buffer_size.tcp_tx_size]); + let mut tcp = tcp::Socket::new(rx_buffer, tx_buffer); + tcp.set_nagle_enabled(false); + + tcp + } + fn alloc_udp_socket(&self) -> udp::Socket<'static> { + let rx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_rx_meta_size], + vec![0; self.buffer_size.udp_rx_size], + ); + let tx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_tx_meta_size], + vec![0; self.buffer_size.udp_tx_size], + ); + let udp = udp::Socket::new(rx_buffer, tx_buffer); + + udp + } + fn alloc_raw_socket( + &self, + ip_version: IpVersion, + ip_protocol: IpProtocol, + ) -> raw::Socket<'static> { + let rx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; self.buffer_size.raw_rx_meta_size], + vec![0; self.buffer_size.raw_rx_size], + ); + let tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; self.buffer_size.raw_tx_meta_size], + vec![0; self.buffer_size.raw_tx_size], + ); + let raw = raw::Socket::new(ip_version, ip_protocol, rx_buffer, tx_buffer); + + raw + } +} + +pub struct SocketHandle(InnerSocketHandle, SharedSocketSet); + +impl SocketHandle { + fn new(inner: InnerSocketHandle, set: SharedSocketSet) -> SocketHandle { + SocketHandle(inner, set) + } +} + +impl Drop for SocketHandle { + fn drop(&mut self) { + let mut iface = self.1.lock(); + iface.remove(self.0); + } +} + +impl Deref for SocketHandle { + type Target = InnerSocketHandle; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SocketHandle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index 7f8d55c..9ef75ad 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -1,5 +1,5 @@ use std::{ - net::{SocketAddr, SocketAddrV4}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::{atomic::AtomicBool, Arc}, time::Duration, }; @@ -129,14 +129,18 @@ impl UdpNatEntry { Ok(()) } - async fn forward_task(self: Arc, mut packet_sender: UnboundedSender) { + async fn forward_task( + self: Arc, + mut packet_sender: UnboundedSender, + virtual_ipv4: Ipv4Addr, + ) { let mut buf = [0u8; 65536]; let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) }; let mut ip_id = 1; loop { let (len, src_socket) = match timeout( - Duration::from_secs(120), + Duration::from_secs(30), self.socket.recv_from(&mut udp_body), ) .await @@ -158,10 +162,14 @@ impl UdpNatEntry { break; } - let SocketAddr::V4(src_v4) = src_socket else { + let SocketAddr::V4(mut src_v4) = src_socket else { continue; }; + if src_v4.ip().is_loopback() { + src_v4.set_ip(virtual_ipv4); + } + let Ok(_) = Self::compose_ipv4_packet( &self, &mut packet_sender, @@ -201,7 +209,10 @@ pub struct UdpProxy { impl UdpProxy { async fn try_handle_packet(&self, packet: &ZCPacket) -> Option<()> { - if self.cidr_set.is_empty() && !self.global_ctx.enable_exit_node() { + if self.cidr_set.is_empty() + && !self.global_ctx.enable_exit_node() + && !self.global_ctx.no_tun() + { return None; } @@ -217,7 +228,11 @@ impl UdpProxy { return None; } - if !self.cidr_set.contains_v4(ipv4.get_destination()) && !is_exit_node { + if !self.cidr_set.contains_v4(ipv4.get_destination()) + && !is_exit_node + && !(self.global_ctx.no_tun() + && Some(ipv4.get_destination()) == self.global_ctx.get_ipv4()) + { return None; } @@ -267,12 +282,19 @@ impl UdpProxy { .replace(tokio::spawn(UdpNatEntry::forward_task( nat_entry.clone(), self.sender.clone(), + self.global_ctx.get_ipv4()?, ))); } // TODO: should it be async. - let dst_socket = - SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination()); + let dst_socket = if Some(ipv4.get_destination()) == self.global_ctx.get_ipv4() { + format!("127.0.0.1:{}", udp_packet.get_destination()) + .parse() + .unwrap() + } else { + SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination()) + }; + let send_ret = { let _g = self.global_ctx.net_ns.guard(); nat_entry diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 7f18f85..e3e0982 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -68,7 +68,8 @@ impl IpProxy { async fn start(&self) -> Result<(), Error> { if (self.global_ctx.get_proxy_cidrs().is_empty() || self.started.load(Ordering::Relaxed)) - && !self.global_ctx.config.get_flags().enable_exit_node + && !self.global_ctx.enable_exit_node() + && !self.global_ctx.no_tun() { return Ok(()); } @@ -502,16 +503,20 @@ impl Instance { self.listener_manager.lock().await.run().await?; self.peer_manager.run().await?; - if self.global_ctx.config.get_dhcp() { - self.check_dhcp_ip_conflict(); - } else if let Some(ipv4_addr) = self.global_ctx.get_ipv4() { - let mut new_nic_ctx = NicCtx::new( - self.global_ctx.clone(), - &self.peer_manager, - self.peer_packet_receiver.clone(), - ); - new_nic_ctx.run(ipv4_addr).await?; - Self::use_new_nic_ctx(self.nic_ctx.clone(), new_nic_ctx).await; + if !self.global_ctx.config.get_flags().no_tun { + if self.global_ctx.config.get_dhcp() { + self.check_dhcp_ip_conflict(); + } else if let Some(ipv4_addr) = self.global_ctx.get_ipv4() { + let mut new_nic_ctx = NicCtx::new( + self.global_ctx.clone(), + &self.peer_manager, + self.peer_packet_receiver.clone(), + ); + new_nic_ctx.run(ipv4_addr).await?; + Self::use_new_nic_ctx(self.nic_ctx.clone(), new_nic_ctx).await; + } + } else { + self.peer_packet_receiver.lock().await.close(); } self.run_rpc_server()?; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 17a4439..83d5f57 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -417,7 +417,7 @@ impl PeerManager { if hdr.packet_type == PacketType::Data as u8 { tracing::trace!(?packet, "send packet to nic channel"); // TODO: use a function to get the body ref directly for zero copy - self.nic_channel.send(packet).await.unwrap(); + let _ = self.nic_channel.send(packet).await; None } else { Some(packet) diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 3a9b3c5..7bde102 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -1441,7 +1441,7 @@ impl Route for PeerRoute { return Some(peer_id); } - tracing::info!(?ipv4_addr, "no peer id for ipv4"); + tracing::debug!(?ipv4_addr, "no peer id for ipv4"); None } diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 2a1de6f..7b7376d 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -60,12 +60,31 @@ pub fn get_inst_config(inst_name: &str, ns: Option<&str>, ipv4: &str) -> TomlCon } pub async fn init_three_node(proto: &str) -> Vec { + init_three_node_ex(proto, |cfg| cfg).await +} + +pub async fn init_three_node_ex TomlConfigLoader>( + proto: &str, + cfg_cb: F, +) -> Vec { log::set_max_level(log::LevelFilter::Info); prepare_linux_namespaces(); - let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_a"), "10.144.144.1")); - let mut inst2 = Instance::new(get_inst_config("inst2", Some("net_b"), "10.144.144.2")); - let mut inst3 = Instance::new(get_inst_config("inst3", Some("net_c"), "10.144.144.3")); + let mut inst1 = Instance::new(cfg_cb(get_inst_config( + "inst1", + Some("net_a"), + "10.144.144.1", + ))); + let mut inst2 = Instance::new(cfg_cb(get_inst_config( + "inst2", + Some("net_b"), + "10.144.144.2", + ))); + let mut inst3 = Instance::new(cfg_cb(get_inst_config( + "inst3", + Some("net_c"), + "10.144.144.3", + ))); inst1.run().await.unwrap(); inst2.run().await.unwrap(); @@ -183,33 +202,79 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr .await; } -#[rstest::rstest] -#[tokio::test] -#[serial_test::serial] -pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { +async fn subnet_proxy_test_udp() { + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; use rand::Rng; - use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; + let udp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); - let mut insts = init_three_node(proto).await; + // NOTE: this should not excced udp tunnel max buffer size + let mut buf = vec![0; 20 * 1024]; + rand::thread_rng().fill(&mut buf[..]); - insts[2] - .get_global_ctx() - .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) - .unwrap(); - insts[2].run_ip_proxy().await.unwrap(); - assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); - - wait_proxy_route_appear( - &insts[0].get_peer_manager(), - "10.144.144.3", - insts[2].peer_id(), - "10.1.2.0/24", + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf, ) .await; - // wait updater - tokio::time::sleep(tokio::time::Duration::from_secs(6)).await; + // no fragment + let udp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); + + let mut buf = vec![0; 1 * 1024]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + // connect to virtual ip (no tun mode) + + let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:22234".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://10.144.144.3:22234".parse().unwrap()); + // NOTE: this should not excced udp tunnel max buffer size + let mut buf = vec![0; 20 * 1024]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + // no fragment + let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:22235".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://10.144.144.3:22235".parse().unwrap()); + + let mut buf = vec![0; 1 * 1024]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; +} + +async fn subnet_proxy_test_tcp() { + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; + use rand::Rng; let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap()); let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap()); @@ -225,29 +290,25 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str buf, ) .await; -} -#[rstest::rstest] -#[tokio::test] -#[serial_test::serial] -pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { - let mut insts = init_three_node(proto).await; + // connect to virtual ip (no tun mode) + let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:22223".parse().unwrap()); + let tcp_connector = TcpTunnelConnector::new("tcp://10.144.144.3:22223".parse().unwrap()); - insts[2] - .get_global_ctx() - .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) - .unwrap(); - insts[2].run_ip_proxy().await.unwrap(); - assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); + let mut buf = vec![0; 32]; + rand::thread_rng().fill(&mut buf[..]); - wait_proxy_route_appear( - &insts[0].get_peer_manager(), - "10.144.144.3", - insts[2].peer_id(), - "10.1.2.0/24", + _tunnel_pingpong_netns( + tcp_listener, + tcp_connector, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf, ) .await; +} +async fn subnet_proxy_test_icmp() { wait_for_condition( || async { ping_test("net_a", "10.1.2.4", None).await }, Duration::from_secs(5), @@ -259,6 +320,52 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st Duration::from_secs(5), ) .await; + + // connect to virtual ip (no tun mode) + wait_for_condition( + || async { ping_test("net_a", "10.144.144.3", None).await }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || async { ping_test("net_a", "10.144.144.3", Some(5 * 1024)).await }, + Duration::from_secs(5), + ) + .await; +} + +#[rstest::rstest] +#[tokio::test] +#[serial_test::serial] +pub async fn subnet_proxy_three_node_test( + #[values("tcp", "udp", "wg")] proto: &str, + #[values(true)] no_tun: bool, +) { + let insts = init_three_node_ex(proto, |cfg| { + if cfg.get_inst_name() == "inst3" { + let mut flags = cfg.get_flags(); + flags.no_tun = no_tun; + cfg.set_flags(flags); + cfg.add_proxy_cidr("10.1.2.0/24".parse().unwrap()); + } + cfg + }) + .await; + + assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); + + wait_proxy_route_appear( + &insts[0].get_peer_manager(), + "10.144.144.3", + insts[2].peer_id(), + "10.1.2.0/24", + ) + .await; + + subnet_proxy_test_icmp().await; + subnet_proxy_test_tcp().await; + subnet_proxy_test_udp().await; } #[cfg(feature = "wireguard")] @@ -328,67 +435,6 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str assert!(ret.is_ok()); } -#[rstest::rstest] -#[tokio::test] -#[serial_test::serial] -pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { - use rand::Rng; - - use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; - - let mut insts = init_three_node(proto).await; - - insts[2] - .get_global_ctx() - .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) - .unwrap(); - insts[2].run_ip_proxy().await.unwrap(); - assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); - - wait_proxy_route_appear( - &insts[0].get_peer_manager(), - "10.144.144.3", - insts[2].peer_id(), - "10.1.2.0/24", - ) - .await; - - // wait updater - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - - let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); - let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); - - // NOTE: this should not excced udp tunnel max buffer size - let mut buf = vec![0; 20 * 1024]; - rand::thread_rng().fill(&mut buf[..]); - - _tunnel_pingpong_netns( - tcp_listener, - tcp_connector, - NetNS::new(Some("net_d".into())), - NetNS::new(Some("net_a".into())), - buf, - ) - .await; - - // no fragment - let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); - let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); - - let mut buf = vec![0; 1 * 1024]; - rand::thread_rng().fill(&mut buf[..]); - - _tunnel_pingpong_netns( - tcp_listener, - tcp_connector, - NetNS::new(Some("net_d".into())), - NetNS::new(Some("net_a".into())), - buf, - ) - .await; -} - #[tokio::test] #[serial_test::serial] pub async fn udp_broadcast_test() { diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index 36e55cd..25d4c4f 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -245,7 +245,7 @@ impl WgPeerData { } } _ => { - tracing::warn!( + tracing::debug!( "Unexpected WireGuard state during decapsulation: {:?}", decapsulate_result );