correctly handle ip fragment for udp/icmp proxy

This commit is contained in:
sijie.sun 2024-06-08 09:23:20 +08:00
parent b2100b78d3
commit b8658852e2
8 changed files with 480 additions and 88 deletions

View File

@ -3,12 +3,13 @@ use std::{
net::{IpAddr, Ipv4Addr, SocketAddrV4},
sync::Arc,
thread,
time::Duration,
};
use pnet::packet::{
icmp::{self, IcmpTypes},
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
ipv4::Ipv4Packet,
Packet,
};
use socket2::Socket;
@ -25,7 +26,10 @@ use crate::{
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::CidrSet;
use super::{
ip_reassembler::{compose_ipv4_packet, IpReassembler},
CidrSet,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct IcmpNatKey {
@ -68,6 +72,8 @@ pub struct IcmpProxy {
nat_table: IcmpNatTable,
tasks: Mutex<JoinSet<()>>,
ip_resemmbler: Arc<IpReassembler>,
}
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
@ -80,7 +86,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, I
}
fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender<ZCPacket>) {
let mut buf = [0u8; 2048];
let mut buf = [0u8; 8192];
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[..]) };
loop {
@ -92,7 +98,7 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
continue;
}
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[..len]) else {
let Some(ipv4_packet) = Ipv4Packet::new(&buf[..len]) else {
continue;
};
@ -120,24 +126,31 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
continue;
};
ipv4_packet.set_destination(dest_ip);
let src_v4 = ipv4_packet.get_source();
let payload_len = len - ipv4_packet.get_header_length() as usize * 4;
let id = ipv4_packet.get_identification();
let _ = compose_ipv4_packet(
&mut buf[..],
&src_v4,
&dest_ip,
IpNextHeaderProtocols::Icmp,
payload_len,
1200,
id,
|buf| {
let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
);
// MacOS do not correctly set ip length when receiving from raw socket
ipv4_packet.set_total_length(len as u16);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
p.fill_peer_manager_hdr(
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
if let Err(e) = sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
}
Ok(())
},
);
if let Err(e) = sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
break;
}
}
}
@ -166,6 +179,8 @@ impl IcmpProxy {
nat_table: Arc::new(dashmap::DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
};
Ok(Arc::new(ret))
@ -226,6 +241,14 @@ impl IcmpProxy {
.instrument(tracing::info_span!("icmp proxy send loop")),
);
let ip_resembler = self.ip_resemmbler.clone();
self.tasks.lock().await.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
ip_resembler.remove_expired_packets();
}
});
self.peer_manager
.add_packet_process_pipeline(Box::new(self.clone()))
.await;
@ -269,7 +292,18 @@ impl IcmpProxy {
return None;
}
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
let resembled_buf: Option<Vec<u8>>;
let icmp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
resembled_buf =
self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() {
return None;
};
icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())?
} else {
icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?
};
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
// drop it because we do not support other icmp types

View File

@ -0,0 +1,299 @@
use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocol;
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
use pnet::packet::Packet;
use std::net::Ipv4Addr;
use std::time::{Duration, Instant};
use crate::common::error::Error;
#[derive(Debug, Clone)]
pub(crate) struct IpFragment {
id: u16,
offset: u16,
data: Vec<u8>,
}
impl<'a> From<&Ipv4Packet<'a>> for IpFragment {
fn from(packet: &Ipv4Packet<'a>) -> Self {
let id = packet.get_identification();
let offset = packet.get_fragment_offset() * 8;
let data = packet.payload().to_vec();
IpFragment { id, offset, data }
}
}
#[derive(Debug, Clone)]
struct IpPacket {
source: Ipv4Addr,
destination: Ipv4Addr,
total_length: Option<u16>,
fragments: Vec<IpFragment>,
}
impl IpPacket {
fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
IpPacket {
source,
destination,
total_length: None,
fragments: Vec::new(),
}
}
fn add_fragment(&mut self, fragment: IpFragment) {
// make sure the fragment doesn't overlap with existing fragments
for f in &self.fragments {
if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 {
return;
}
if fragment.offset <= f.offset
&& f.offset < fragment.offset + fragment.data.len() as u16
{
return;
}
}
self.fragments.push(fragment);
}
fn is_complete(&self) -> bool {
if self.total_length.is_none() {
return false;
}
let mut total_length = 0;
for fragment in &self.fragments {
total_length += fragment.data.len() as u16;
}
tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete");
Some(total_length) == self.total_length
}
fn set_total_length(&mut self, total_length: u16) {
self.total_length = Some(total_length);
}
fn assemble(&mut self) -> Option<Vec<u8>> {
if !self.is_complete() {
return None;
}
// sort fragments by offset
self.fragments.sort_by_key(|f| f.offset);
let mut packet = vec![0u8; self.total_length.unwrap() as usize];
for fragment in &self.fragments {
let start = fragment.offset as usize;
let end = start + fragment.data.len();
packet[start..end].copy_from_slice(&fragment.data);
}
Some(packet)
}
}
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct IpResemblerKey {
source: Ipv4Addr,
destination: Ipv4Addr,
id: u16,
}
#[derive(Debug)]
struct IpResemblerValue {
packet: IpPacket,
timestamp: Instant,
}
#[derive(Debug)]
pub(crate) struct IpReassembler {
packets: DashMap<IpResemblerKey, IpResemblerValue>,
timeout: Duration,
}
impl IpReassembler {
pub fn new(timeout: Duration) -> Self {
IpReassembler {
packets: DashMap::new(),
timeout,
}
}
pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool {
packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0
}
pub fn is_last_fragment(packet: &Ipv4Packet) -> bool {
packet.get_flags() & Ipv4Flags::MoreFragments == 0
}
pub fn add_fragment(
&self,
source: Ipv4Addr,
destination: Ipv4Addr,
packet: &Ipv4Packet,
) -> Option<Vec<u8>> {
let id = packet.get_identification();
let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4;
if total_length != packet.payload().len() as u16 {
tracing::trace!(
?packet,
?total_length,
payload_len = ?packet.payload().len(),
"unexpected total length",
);
return None;
}
let fragment: IpFragment = packet.into();
let key = IpResemblerKey {
source,
destination,
id,
};
let mut entry = self.packets.entry(key.clone()).or_insert_with(|| {
let packet = IpPacket::new(source, destination);
let timestamp = Instant::now();
IpResemblerValue { packet, timestamp }
});
let value_mut = entry.value_mut();
if Self::is_last_fragment(packet) {
value_mut
.packet
.set_total_length(total_length + fragment.offset);
}
value_mut.packet.add_fragment(fragment);
if let Some(data) = value_mut.packet.assemble() {
drop(entry);
self.packets.remove(&key);
Some(data)
} else {
value_mut.timestamp = Instant::now();
None
}
}
pub fn remove_expired_packets(&self) {
let timeout = self.timeout;
self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout);
}
}
// ip payload should be in buf[20..]
pub fn compose_ipv4_packet<F>(
buf: &mut [u8],
src_v4: &Ipv4Addr,
dst_v4: &Ipv4Addr,
next_protocol: IpNextHeaderProtocol,
payload_len: usize,
payload_mtu: usize,
ip_id: u16,
cb: F,
) -> Result<(), Error>
where
F: Fn(&[u8]) -> Result<(), Error>,
{
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
} else {
ipv4_packet.set_flags(0);
}
assert_eq!(0, fragment_offset % 8);
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
} else {
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
ipv4_packet.set_fragment_offset(0);
}
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.clone());
ipv4_packet.set_destination(dst_v4.clone());
ipv4_packet.set_next_level_protocol(next_protocol);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
tracing::trace!(?ipv4_packet, "udp nat packet response send");
cb(ipv4_packet.packet())?;
buf_offset += next_fragment_offset - fragment_offset;
fragment_offset = next_fragment_offset;
cur_piece += 1;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resembler() {
let raw_packets = vec![
// last packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
],
// 1st packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07,
],
// 2nd packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
// expired packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
],
];
let source = "192.168.0.1".parse().unwrap();
let destination = "192.168.0.2".parse().unwrap();
let resembler = IpReassembler::new(Duration::from_secs(1));
for (idx, raw_packet) in raw_packets.iter().enumerate() {
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
let ret = resembler.add_fragment(source, destination, &packet);
if idx != 2 {
assert!(ret.is_none());
} else {
assert!(ret.is_some());
}
println!(
"packet: {:?}, ret: {:?}, palyload_len: {}",
packet,
ret,
packet.payload().len()
);
}
}
resembler.remove_expired_packets();
assert_eq!(1, resembler.packets.len());
std::thread::sleep(Duration::from_secs(2));
resembler.remove_expired_packets();
assert_eq!(0, resembler.packets.len());
}
}

View File

@ -4,6 +4,7 @@ use tokio::task::JoinSet;
use crate::common::global_ctx::ArcGlobalCtx;
pub mod icmp_proxy;
pub mod ip_reassembler;
pub mod tcp_proxy;
pub mod udp_proxy;

View File

@ -7,7 +7,7 @@ use std::{
use dashmap::DashMap;
use pnet::packet::{
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
ipv4::Ipv4Packet,
udp::{self, MutableUdpPacket},
Packet,
};
@ -25,6 +25,7 @@ use tracing::Level;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
gateway::ip_reassembler::compose_ipv4_packet,
peers::{peer_manager::PeerManager, PeerPacketFilter},
tunnel::{
common::setup_sokcet2,
@ -32,7 +33,7 @@ use crate::{
},
};
use super::CidrSet;
use super::{ip_reassembler::IpReassembler, CidrSet};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey {
@ -105,60 +106,31 @@ impl UdpNatEntry {
nat_src_v4.ip(),
));
let payload_len = payload_len + 8; // include udp header
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
.unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
} else {
ipv4_packet.set_flags(0);
compose_ipv4_packet(
&mut buf[..],
src_v4.ip(),
nat_src_v4.ip(),
IpNextHeaderProtocols::Udp,
payload_len + 8, // include udp header
payload_mtu,
ip_id,
|buf| {
let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
if let Err(e) = packet_sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
return Err(Error::AnyhowError(e.into()));
}
assert_eq!(0, fragment_offset % 8);
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
} else {
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
ipv4_packet.set_fragment_offset(0);
}
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.ip().clone());
ipv4_packet.set_destination(nat_src_v4.ip().clone());
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
Ok(())
},
)?;
tracing::trace!(?ipv4_packet, "udp nat packet response send");
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
if let Err(e) = packet_sender.send(p) {
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
return Err(Error::AnyhowError(e.into()));
}
buf_offset += next_fragment_offset - fragment_offset;
fragment_offset = next_fragment_offset;
cur_piece += 1;
}
Ok(())
}
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<ZCPacket>) {
let mut buf = [0u8; 8192];
let mut buf = [0u8; 65536];
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
let mut ip_id = 1;
@ -223,6 +195,8 @@ pub struct UdpProxy {
receiver: Mutex<Option<UnboundedReceiver<ZCPacket>>>,
tasks: Mutex<JoinSet<()>>,
ip_resemmbler: Arc<IpReassembler>,
}
impl UdpProxy {
@ -247,7 +221,18 @@ impl UdpProxy {
return None;
}
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
let resembled_buf: Option<Vec<u8>>;
let udp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
resembled_buf =
self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() {
return None;
};
udp::UdpPacket::new(resembled_buf.as_ref().unwrap())?
} else {
udp::UdpPacket::new(ipv4.payload())?
};
tracing::trace!(
?packet,
@ -336,6 +321,7 @@ impl UdpProxy {
sender,
receiver: Mutex::new(Some(receiver)),
tasks: Mutex::new(JoinSet::new()),
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
};
Ok(Arc::new(ret))
}
@ -362,6 +348,14 @@ impl UdpProxy {
}
});
let ip_resembler = self.ip_resemmbler.clone();
self.tasks.lock().await.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
ip_resembler.remove_expired_packets();
}
});
// forward packets to peer manager
let mut receiver = self.receiver.lock().await.take().unwrap();
let peer_manager = self.peer_manager.clone();

View File

@ -292,7 +292,10 @@ impl VirtualNic {
config.platform(|config| {
config.skip_config(true);
config.guid(None);
config.ring_cap(Some(config.min_ring_cap() * 2));
config.ring_cap(Some(std::cmp::min(
config.min_ring_cap() * 32,
config.max_ring_cap(),
)));
});
}

View File

@ -136,7 +136,7 @@ pub async fn init_three_node(proto: &str) -> Vec<Instance> {
vec![inst1, inst2, inst3]
}
async fn ping_test(from_netns: &str, target_ip: &str) -> bool {
async fn ping_test(from_netns: &str, target_ip: &str, payload_size: Option<usize>) -> bool {
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
let code = tokio::process::Command::new("ip")
.args(&[
@ -146,6 +146,8 @@ async fn ping_test(from_netns: &str, target_ip: &str) -> bool {
"ping",
"-c",
"1",
"-s",
payload_size.unwrap_or(56).to_string().as_str(),
"-W",
"1",
target_ip.to_string().as_str(),
@ -175,7 +177,7 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr
);
wait_for_condition(
|| async { ping_test("net_c", "10.144.144.1").await },
|| async { ping_test("net_c", "10.144.144.1", None).await },
Duration::from_secs(5000),
)
.await;
@ -185,6 +187,8 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr
#[tokio::test]
#[serial_test::serial]
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
use rand::Rng;
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener};
let mut insts = init_three_node(proto).await;
@ -210,11 +214,15 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str
let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
let mut buf = vec![0; 32];
rand::thread_rng().fill(&mut buf[..]);
_tunnel_pingpong_netns(
tcp_listener,
tcp_connector,
NetNS::new(Some("net_d".into())),
NetNS::new(Some("net_a".into())),
buf,
)
.await;
}
@ -241,7 +249,13 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st
.await;
wait_for_condition(
|| async { ping_test("net_a", "10.1.2.4").await },
|| async { ping_test("net_a", "10.1.2.4", None).await },
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async { ping_test("net_a", "10.1.2.4", Some(5 * 1024)).await },
Duration::from_secs(5),
)
.await;
@ -318,6 +332,8 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
#[tokio::test]
#[serial_test::serial]
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
use rand::Rng;
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener};
let mut insts = init_three_node(proto).await;
@ -343,11 +359,32 @@ pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str
let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
// NOTE: this should not excced udp tunnel max buffer size
let mut buf = vec![0; 20 * 1024];
rand::thread_rng().fill(&mut buf[..]);
_tunnel_pingpong_netns(
tcp_listener,
tcp_connector,
NetNS::new(Some("net_d".into())),
NetNS::new(Some("net_a".into())),
buf,
)
.await;
// no fragment
let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
let mut buf = vec![0; 1 * 1024];
rand::thread_rng().fill(&mut buf[..]);
_tunnel_pingpong_netns(
tcp_listener,
tcp_connector,
NetNS::new(Some("net_d".into())),
NetNS::new(Some("net_a".into())),
buf,
)
.await;
}
@ -443,7 +480,7 @@ pub async fn foreign_network_forward_nic_data() {
.await;
wait_for_condition(
|| async { ping_test("net_b", "10.144.145.2").await },
|| async { ping_test("net_b", "10.144.145.2", None).await },
Duration::from_secs(5),
)
.await;
@ -531,19 +568,19 @@ pub async fn wireguard_vpn_portal() {
// ping other node in network
wait_for_condition(
|| async { ping_test("net_d", "10.144.144.1").await },
|| async { ping_test("net_d", "10.144.144.1", None).await },
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async { ping_test("net_d", "10.144.144.2").await },
|| async { ping_test("net_d", "10.144.144.2", None).await },
Duration::from_secs(5),
)
.await;
// ping portal node
wait_for_condition(
|| async { ping_test("net_d", "10.144.144.3").await },
|| async { ping_test("net_d", "10.144.144.3", None).await },
Duration::from_secs(5),
)
.await;

View File

@ -107,7 +107,10 @@ impl<R> FramedReader<R> {
}
}
fn extract_one_packet(buf: &mut BytesMut) -> Option<ZCPacket> {
fn extract_one_packet(
buf: &mut BytesMut,
max_packet_size: usize,
) -> Option<Result<ZCPacket, TunnelError>> {
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
// header is not complete
return None;
@ -115,6 +118,11 @@ impl<R> FramedReader<R> {
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
let body_len = header.len.get() as usize;
if body_len > max_packet_size {
// body is too long
return Some(Err(TunnelError::InvalidPacket("body too long".to_string())));
}
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
// body is not complete
return None;
@ -122,7 +130,7 @@ impl<R> FramedReader<R> {
// extract one packet
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))
Some(Ok(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)))
}
}
@ -139,8 +147,10 @@ where
let mut self_mut = self.project();
loop {
while let Some(packet) = Self::extract_one_packet(self_mut.buf) {
return Poll::Ready(Some(Ok(packet)));
while let Some(packet) =
Self::extract_one_packet(self_mut.buf, *self_mut.max_packet_size)
{
return Poll::Ready(Some(packet));
}
reserve_buf(
@ -465,7 +475,14 @@ pub mod tests {
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
_tunnel_pingpong_netns(
listener,
connector,
NetNS::new(None),
NetNS::new(None),
"12345678abcdefg".as_bytes().to_vec(),
)
.await;
}
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
@ -473,6 +490,7 @@ pub mod tests {
mut connector: C,
l_netns: NetNS,
c_netns: NetNS,
buf: Vec<u8>,
) where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
@ -503,7 +521,7 @@ pub mod tests {
let (mut recv, mut send) = tunnel.split();
send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes()))
send.send(ZCPacket::new_with_payload(buf.as_slice()))
.await
.unwrap();
@ -513,7 +531,7 @@ pub mod tests {
.unwrap()
.unwrap();
println!("echo back: {:?}", ret);
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
assert_eq!(ret.payload(), Bytes::from(buf));
send.close().await.unwrap();

View File

@ -158,7 +158,13 @@ where
let mut buf = BytesMut::new();
loop {
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16);
let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap();
let (dg_size, addr) = match socket.recv_buf_from(&mut buf).await {
Ok(v) => v,
Err(e) => {
tracing::error!(?e, "udp recv from socket error");
break;
}
};
tracing::trace!(
"udp recv packet: {:?}, buf: {:?}, size: {}",
addr,