diff --git a/easytier-core/src/connector/udp_hole_punch.rs b/easytier-core/src/connector/udp_hole_punch.rs index b7db5c8..14bbe1e 100644 --- a/easytier-core/src/connector/udp_hole_punch.rs +++ b/easytier-core/src/connector/udp_hole_punch.rs @@ -56,7 +56,10 @@ impl UdpHolePunchListener { let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap()); - listener.listen().await?; + { + let _g = peer_mgr.get_global_ctx().net_ns.guard(); + listener.listen().await?; + } let socket = listener.get_socket().unwrap(); let running = Arc::new(AtomicCell::new(true)); @@ -339,7 +342,10 @@ impl UdpHolePunchConnector { ) -> Result, anyhow::Error> { tracing::info!(?dst_peer_id, "start hole punching"); // client: choose a local udp port, and get the pubic mapped port from stun server - let socket = UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?; + let socket = { + let _g = data.global_ctx.net_ns.guard(); + UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")? + }; let local_socket_addr = socket.local_addr()?; let local_port = socket.local_addr()?.port(); drop(socket); // drop the socket to release the port @@ -388,6 +394,7 @@ impl UdpHolePunchConnector { .unwrap(), ); + let _g = data.global_ctx.net_ns.guard(); let socket2_socket = socket2::Socket::new( socket2::Domain::for_address(local_socket_addr), socket2::Type::DGRAM, diff --git a/easytier-core/src/gateway/mod.rs b/easytier-core/src/gateway/mod.rs index 07b2815..1466e58 100644 --- a/easytier-core/src/gateway/mod.rs +++ b/easytier-core/src/gateway/mod.rs @@ -6,6 +6,7 @@ use crate::common::global_ctx::ArcGlobalCtx; pub mod icmp_proxy; pub mod tcp_proxy; +pub mod udp_proxy; #[derive(Debug)] struct CidrSet { @@ -48,4 +49,8 @@ impl CidrSet { let ip = ip.into(); return self.cidr_set.iter().any(|cidr| cidr.contains(&ip)); } + + pub fn is_empty(&self) -> bool { + return self.cidr_set.is_empty(); + } } diff --git a/easytier-core/src/gateway/udp_proxy.rs b/easytier-core/src/gateway/udp_proxy.rs new file mode 100644 
index 0000000..192f658 --- /dev/null +++ b/easytier-core/src/gateway/udp_proxy.rs @@ -0,0 +1,387 @@ +use std::{ + net::{SocketAddr, SocketAddrV4}, + sync::{atomic::AtomicBool, Arc}, + time::Duration, +}; + +use dashmap::DashMap; +use pnet::packet::{ + ip::IpNextHeaderProtocols, + ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet}, + udp::{self, MutableUdpPacket}, + Packet, +}; +use tokio::{ + net::UdpSocket, + sync::{ + mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + Mutex, + }, + task::{JoinHandle, JoinSet}, + time::timeout, +}; + +use tokio_util::bytes::Bytes; +use tracing::Level; + +use crate::{ + common::{error::Error, global_ctx::ArcGlobalCtx}, + peers::{ + packet, + peer_manager::{PeerManager, PeerPacketFilter}, + PeerId, + }, + tunnels::common::setup_sokcet2, +}; + +use super::CidrSet; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct UdpNatKey { + src_socket: SocketAddr, +} + +#[derive(Debug)] +struct UdpNatEntry { + src_peer_id: PeerId, + my_peer_id: PeerId, + src_socket: SocketAddr, + socket: UdpSocket, + forward_task: Mutex>>, + stopped: AtomicBool, + start_time: std::time::Instant, +} + +impl UdpNatEntry { + #[tracing::instrument(err(level = Level::WARN))] + fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr) -> Result { + // TODO: try use src port, so we will be ip restricted nat type + let socket2_socket = socket2::Socket::new( + socket2::Domain::IPV4, + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + let dst_socket_addr = "0.0.0.0:0".parse().unwrap(); + setup_sokcet2(&socket2_socket, &dst_socket_addr)?; + let socket = UdpSocket::from_std(socket2_socket.into())?; + + Ok(Self { + src_peer_id, + my_peer_id, + src_socket, + socket, + forward_task: Mutex::new(None), + stopped: AtomicBool::new(false), + start_time: std::time::Instant::now(), + }) + } + + pub fn stop(&self) { + self.stopped + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + async fn compose_ipv4_packet( + self: 
&Arc, + packet_sender: &mut UnboundedSender, + buf: &mut [u8], + src_v4: &SocketAddrV4, + payload_len: usize, + payload_mtu: usize, + ip_id: u16, + ) -> Result<(), Error> { + let SocketAddr::V4(nat_src_v4) = self.src_socket else { + return Err(Error::Unknown); + }; + + assert_eq!(0, payload_mtu % 8); + + // udp payload is in buf[20 + 8..] + let mut udp_packet = MutableUdpPacket::new(&mut buf[20..28 + payload_len]).unwrap(); + udp_packet.set_source(src_v4.port()); + udp_packet.set_destination(self.src_socket.port()); + udp_packet.set_length(payload_len as u16 + 8); + udp_packet.set_checksum(udp::ipv4_checksum( + &udp_packet.to_immutable(), + src_v4.ip(), + nat_src_v4.ip(), + )); + + let payload_len = payload_len + 8; // include udp header + let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu; + let mut buf_offset = 0; + let mut fragment_offset = 0; + let mut cur_piece = 0; + while fragment_offset < payload_len { + let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len); + let fragment_len = next_fragment_offset - fragment_offset; + let mut ipv4_packet = + MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]) + .unwrap(); + ipv4_packet.set_version(4); + ipv4_packet.set_header_length(5); + ipv4_packet.set_total_length((fragment_len + 20) as u16); + ipv4_packet.set_identification(ip_id); + if total_pieces > 1 { + if cur_piece != total_pieces - 1 { + ipv4_packet.set_flags(Ipv4Flags::MoreFragments); + } else { + ipv4_packet.set_flags(0); + } + assert_eq!(0, fragment_offset % 8); + ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8); + } else { + ipv4_packet.set_flags(Ipv4Flags::DontFragment); + ipv4_packet.set_fragment_offset(0); + } + ipv4_packet.set_ecn(0); + ipv4_packet.set_dscp(0); + ipv4_packet.set_ttl(32); + ipv4_packet.set_source(src_v4.ip().clone()); + ipv4_packet.set_destination(nat_src_v4.ip().clone()); + ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp); + 
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
+
+            tracing::trace!(?ipv4_packet, "udp nat packet response send");
+
+            let peer_packet = packet::Packet::new_data_packet(
+                self.my_peer_id,
+                self.src_peer_id,
+                &ipv4_packet.to_immutable().packet(),
+            );
+
+            if let Err(e) = packet_sender.send(peer_packet) {
+                tracing::error!("send udp packet to peer failed: {:?}, may exiting..", e);
+                return Err(Error::AnyhowError(e.into()));
+            }
+
+            buf_offset += next_fragment_offset - fragment_offset; // advance by the bytes just sent
+            fragment_offset = next_fragment_offset;
+            cur_piece += 1;
+        }
+        Ok(())
+    }
+    /// Relays datagrams received on the NAT socket back to the source peer until error, timeout or stop.
+    async fn forward_task(self: Arc, mut packet_sender: UnboundedSender) {
+        let mut buf = [0u8; 8192];
+        // Datagrams are received into buf[20 + 8..]; the leading bytes are left for the IPv4 + UDP headers.
+        let mut ip_id: u16 = 1;
+
+        loop {
+            let (len, src_socket) = match timeout(
+                Duration::from_secs(120), // give up when nothing arrives for 120s
+                self.socket.recv_from(&mut buf[20 + 8..]), // safe reborrow, released after the await
+            )
+            .await
+            {
+                Ok(Ok(x)) => x,
+                Ok(Err(err)) => {
+                    tracing::error!(?err, "udp nat recv failed");
+                    break;
+                }
+                Err(err) => {
+                    tracing::error!(?err, "udp nat recv timeout");
+                    break;
+                }
+            };
+
+            tracing::trace!(?len, ?src_socket, "udp nat packet response received");
+
+            if self.stopped.load(std::sync::atomic::Ordering::Relaxed) {
+                break;
+            }
+
+            let SocketAddr::V4(src_v4) = src_socket else {
+                continue; // only IPv4 replies can be wrapped into an IPv4 packet
+            };
+
+            let Ok(_) = Self::compose_ipv4_packet(
+                &self,
+                &mut packet_sender,
+                &mut buf,
+                &src_v4,
+                len,
+                1200,
+                ip_id,
+            )
+            .await
+            else {
+                break;
+            };
+            ip_id = ip_id.wrapping_add(1); // 16-bit IPv4 identification: wrap instead of overflowing
+        }
+
+        self.stop();
+    }
+}
+
+#[derive(Debug)]
+pub struct UdpProxy {
+    global_ctx: ArcGlobalCtx,
+    peer_manager: Arc,
+
+    cidr_set: CidrSet,
+
+    nat_table: Arc>>,
+
+    sender: UnboundedSender,
+    receiver: Mutex>>,
+
+    tasks: Mutex>,
+}
+
+#[async_trait::async_trait]
+impl PeerPacketFilter for UdpProxy {
+    async fn try_process_packet_from_peer(
+        &self,
+        packet: &packet::ArchivedPacket,
+        _: &Bytes,
+    ) -> Option<()> {
+        if self.cidr_set.is_empty() {
+            return None;
+        }
+
+        let _ =
self.global_ctx.get_ipv4()?;
+
+        let packet::ArchivedPacketBody::Data(x) = &packet.body else {
+            return None;
+        };
+
+        let ipv4 = Ipv4Packet::new(&x.data)?;
+
+        if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
+            return None;
+        }
+
+        if !self.cidr_set.contains_v4(ipv4.get_destination()) {
+            return None; // destination is outside the proxied CIDRs: leave the packet to other filters
+        }
+
+        let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
+
+        tracing::trace!(
+            ?packet,
+            ?ipv4,
+            ?udp_packet,
+            "udp nat packet request received"
+        );
+
+        let nat_key = UdpNatKey {
+            src_socket: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()),
+        };
+        let nat_entry = self
+            .nat_table
+            .entry(nat_key)
+            .or_try_insert_with::(|| {
+                tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
+                let _g = self.global_ctx.net_ns.guard();
+                Ok(Arc::new(UdpNatEntry::new(
+                    packet.from_peer.to_uuid(),
+                    packet.to_peer.as_ref().ok_or(Error::Unknown)?.to_uuid(), // no panic on a missing to_peer
+                    nat_key.src_socket,
+                )?))
+            })
+            .ok()?
+            .clone();
+
+        {
+            // Lock once so check-and-spawn is atomic: concurrent packets must start only one forwarder.
+            let mut forward_task = nat_entry.forward_task.lock().await;
+            if forward_task.is_none() {
+                forward_task.replace(tokio::spawn(UdpNatEntry::forward_task(
+                    nat_entry.clone(),
+                    self.sender.clone(),
+                )));
+            }
+        }
+
+        // TODO: should it be async.
+        let dst_socket =
+            SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination());
+        let send_ret = {
+            let _g = self.global_ctx.net_ns.guard(); // NOTE(review): guard is held across the await below; confirm intended
+            nat_entry
+                .socket
+                .send_to(udp_packet.payload(), dst_socket)
+                .await
+        };
+
+        if let Err(send_err) = send_ret {
+            tracing::error!(
+                ?send_err,
+                ?nat_key,
+                ?nat_entry,
+                ?dst_socket,
+                "udp nat send failed"
+            );
+        }
+
+        Some(()) // the packet was handled here, even if the send itself failed
+    }
+}
+
+impl UdpProxy {
+    pub fn new(
+        global_ctx: ArcGlobalCtx,
+        peer_manager: Arc,
+    ) -> Result, Error> {
+        let cidr_set = CidrSet::new(global_ctx.clone());
+        let (sender, receiver) = unbounded_channel();
+        let ret = Self {
+            global_ctx,
+            peer_manager,
+            cidr_set,
+            nat_table: Arc::new(DashMap::new()),
+            sender,
+            receiver: Mutex::new(Some(receiver)),
+            tasks: Mutex::new(JoinSet::new()),
+        };
+        Ok(Arc::new(ret))
+    }
+    /// Registers this proxy as a packet filter and spawns the NAT cleanup and response forwarding tasks.
+    pub async fn start(self: &Arc) -> Result<(), Error> {
+        self.peer_manager
+            .add_packet_process_pipeline(Box::new(self.clone()))
+            .await;
+
+        // clean up nat table: every 15s, stop and evict entries by age since creation (not idleness)
+        let nat_table = self.nat_table.clone();
+        self.tasks.lock().await.spawn(async move {
+            loop {
+                tokio::time::sleep(Duration::from_secs(15)).await;
+                nat_table.retain(|_, v| {
+                    if v.start_time.elapsed().as_secs() > 120 {
+                        tracing::info!(?v, "udp nat table entry removed");
+                        v.stop();
+                        false
+                    } else {
+                        true
+                    }
+                });
+            }
+        });
+
+        // forward packets to peer manager
+        let mut receiver = self.receiver.lock().await.take().unwrap(); // taken once: a second start() panics
+        let peer_manager = self.peer_manager.clone();
+        self.tasks.lock().await.spawn(async move {
+            while let Some(msg) = receiver.recv().await {
+                let to_peer_id: uuid::Uuid = msg.to_peer.as_ref().unwrap().clone().into();
+                tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
+                let ret = peer_manager.send_msg(msg.into(), &to_peer_id).await;
+                if ret.is_err() {
+                    tracing::error!("send udp packet to peer failed: {:?}", ret);
+                }
+            }
+        });
+        Ok(())
+    }
+}
+
+impl Drop for UdpProxy {
+    fn drop(&mut self) {
+        for v in self.nat_table.iter() {
+            v.stop();
+ } + } +} diff --git a/easytier-core/src/instance/instance.rs b/easytier-core/src/instance/instance.rs index ffd1d68..c452cd0 100644 --- a/easytier-core/src/instance/instance.rs +++ b/easytier-core/src/instance/instance.rs @@ -20,6 +20,7 @@ use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManage use crate::connector::udp_hole_punch::UdpHolePunchConnector; use crate::gateway::icmp_proxy::IcmpProxy; use crate::gateway::tcp_proxy::TcpProxy; +use crate::gateway::udp_proxy::UdpProxy; use crate::peer_center::instance::PeerCenterInstance; use crate::peers::peer_manager::PeerManager; use crate::peers::rpc_service::PeerManagerRpcService; @@ -85,6 +86,7 @@ pub struct Instance { tcp_proxy: Arc, icmp_proxy: Arc, + udp_proxy: Arc, peer_center: Arc, @@ -143,6 +145,7 @@ impl Instance { let arc_tcp_proxy = TcpProxy::new(global_ctx.clone(), peer_manager.clone()); let arc_icmp_proxy = IcmpProxy::new(global_ctx.clone(), peer_manager.clone()).unwrap(); + let arc_udp_proxy = UdpProxy::new(global_ctx.clone(), peer_manager.clone()).unwrap(); let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone())); @@ -162,6 +165,7 @@ impl Instance { tcp_proxy: arc_tcp_proxy, icmp_proxy: arc_icmp_proxy, + udp_proxy: arc_udp_proxy, peer_center, @@ -284,6 +288,7 @@ impl Instance { self.tcp_proxy.start().await.unwrap(); self.icmp_proxy.start().await.unwrap(); + self.udp_proxy.start().await.unwrap(); self.run_proxy_cidrs_route_updater(); self.udp_hole_puncher.lock().await.run().await?; diff --git a/easytier-core/src/peer_center/server.rs b/easytier-core/src/peer_center/server.rs index 9487282..3d82065 100644 --- a/easytier-core/src/peer_center/server.rs +++ b/easytier-core/src/peer_center/server.rs @@ -98,7 +98,7 @@ impl PeerCenterService for PeerCenterServer { peers: Option, digest: Digest, ) -> Result<(), Error> { - tracing::warn!("receive report_peers"); + tracing::info!("receive report_peers"); let data = get_global_data(self.my_node_id); let mut locked_data = 
data.write().await; diff --git a/easytier-core/src/tests/three_node.rs b/easytier-core/src/tests/three_node.rs index fbcec17..3381706 100644 --- a/easytier-core/src/tests/three_node.rs +++ b/easytier-core/src/tests/three_node.rs @@ -7,7 +7,7 @@ use crate::{ common::tests::_tunnel_pingpong_netns, ring_tunnel::RingTunnelConnector, tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener}, - udp_tunnel::UdpTunnelConnector, + udp_tunnel::{UdpTunnelConnector, UdpTunnelListener}, }, }; @@ -225,3 +225,37 @@ pub async fn proxy_three_node_disconnect_test() { // TODO: add some traffic here, also should check route & peer list tokio::time::sleep(tokio::time::Duration::from_secs(35)).await; } + +#[tokio::test] +#[serial_test::serial] +pub async fn udp_proxy_three_node_test() { + let insts = init_three_node("tcp").await; + + insts[2] + .get_global_ctx() + .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) + .unwrap(); + assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); + + wait_proxy_route_appear( + &insts[0].get_peer_manager(), + "10.144.144.3", + insts[2].id(), + "10.1.2.0/24", + ) + .await; + + // wait updater + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + + let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap()); + let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap()); + + _tunnel_pingpong_netns( + tcp_listener, + tcp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + ) + .await; +} diff --git a/easytier-core/src/tunnels/common.rs b/easytier-core/src/tunnels/common.rs index cae44d5..e8ef12e 100644 --- a/easytier-core/src/tunnels/common.rs +++ b/easytier-core/src/tunnels/common.rs @@ -286,6 +286,10 @@ pub(crate) fn setup_sokcet2( #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] socket2_socket.set_reuse_port(true)?; + if bind_addr.ip().is_unspecified() { + return Ok(()); + } + // linux/mac does not use interface of bind_addr to send 
packet, so we need to bind device // win can handle this with bind correctly #[cfg(any(target_os = "ios", target_os = "macos"))] @@ -380,7 +384,7 @@ pub mod tests { let mut send = tunnel.pin_sink(); let mut recv = tunnel.pin_stream(); - let send_data = Bytes::from("abc"); + let send_data = Bytes::from("12345678abcdefg"); send.send(send_data).await.unwrap(); let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next()) .await @@ -388,7 +392,7 @@ pub mod tests { .unwrap() .unwrap(); println!("echo back: {:?}", ret); - assert_eq!(ret, Bytes::from("abc")); + assert_eq!(ret, Bytes::from("12345678abcdefg")); close_tunnel(&tunnel).await.unwrap();