mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-16 11:42:27 +08:00
correctly handle ip fragment for udp/icmp proxy
This commit is contained in:
parent
b2100b78d3
commit
b8658852e2
|
@ -3,12 +3,13 @@ use std::{
|
|||
net::{IpAddr, Ipv4Addr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
thread,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use pnet::packet::{
|
||||
icmp::{self, IcmpTypes},
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
|
||||
ipv4::Ipv4Packet,
|
||||
Packet,
|
||||
};
|
||||
use socket2::Socket;
|
||||
|
@ -25,7 +26,10 @@ use crate::{
|
|||
tunnel::packet_def::{PacketType, ZCPacket},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
use super::{
|
||||
ip_reassembler::{compose_ipv4_packet, IpReassembler},
|
||||
CidrSet,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct IcmpNatKey {
|
||||
|
@ -68,6 +72,8 @@ pub struct IcmpProxy {
|
|||
nat_table: IcmpNatTable,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
|
||||
ip_resemmbler: Arc<IpReassembler>,
|
||||
}
|
||||
|
||||
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
|
||||
|
@ -80,7 +86,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, I
|
|||
}
|
||||
|
||||
fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender<ZCPacket>) {
|
||||
let mut buf = [0u8; 2048];
|
||||
let mut buf = [0u8; 8192];
|
||||
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[..]) };
|
||||
|
||||
loop {
|
||||
|
@ -92,7 +98,7 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
|
|||
continue;
|
||||
}
|
||||
|
||||
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[..len]) else {
|
||||
let Some(ipv4_packet) = Ipv4Packet::new(&buf[..len]) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
|
@ -120,24 +126,31 @@ fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSe
|
|||
continue;
|
||||
};
|
||||
|
||||
ipv4_packet.set_destination(dest_ip);
|
||||
let src_v4 = ipv4_packet.get_source();
|
||||
let payload_len = len - ipv4_packet.get_header_length() as usize * 4;
|
||||
let id = ipv4_packet.get_identification();
|
||||
let _ = compose_ipv4_packet(
|
||||
&mut buf[..],
|
||||
&src_v4,
|
||||
&dest_ip,
|
||||
IpNextHeaderProtocols::Icmp,
|
||||
payload_len,
|
||||
1200,
|
||||
id,
|
||||
|buf| {
|
||||
let mut p = ZCPacket::new_with_payload(buf);
|
||||
p.fill_peer_manager_hdr(
|
||||
v.my_peer_id.into(),
|
||||
v.src_peer_id.into(),
|
||||
PacketType::Data as u8,
|
||||
);
|
||||
|
||||
// MacOS do not correctly set ip length when receiving from raw socket
|
||||
ipv4_packet.set_total_length(len as u16);
|
||||
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
|
||||
p.fill_peer_manager_hdr(
|
||||
v.my_peer_id.into(),
|
||||
v.src_peer_id.into(),
|
||||
PacketType::Data as u8,
|
||||
if let Err(e) = sender.send(p) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
||||
if let Err(e) = sender.send(p) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,6 +179,8 @@ impl IcmpProxy {
|
|||
|
||||
nat_table: Arc::new(dashmap::DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
|
||||
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
|
||||
};
|
||||
|
||||
Ok(Arc::new(ret))
|
||||
|
@ -226,6 +241,14 @@ impl IcmpProxy {
|
|||
.instrument(tracing::info_span!("icmp proxy send loop")),
|
||||
);
|
||||
|
||||
let ip_resembler = self.ip_resemmbler.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
ip_resembler.remove_expired_packets();
|
||||
}
|
||||
});
|
||||
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
|
@ -269,7 +292,18 @@ impl IcmpProxy {
|
|||
return None;
|
||||
}
|
||||
|
||||
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
||||
let resembled_buf: Option<Vec<u8>>;
|
||||
let icmp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
|
||||
resembled_buf =
|
||||
self.ip_resemmbler
|
||||
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
|
||||
if resembled_buf.is_none() {
|
||||
return None;
|
||||
};
|
||||
icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())?
|
||||
} else {
|
||||
icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?
|
||||
};
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
||||
// drop it because we do not support other icmp types
|
||||
|
|
299
easytier/src/gateway/ip_reassembler.rs
Normal file
299
easytier/src/gateway/ip_reassembler.rs
Normal file
|
@ -0,0 +1,299 @@
|
|||
use dashmap::DashMap;
|
||||
use pnet::packet::ip::IpNextHeaderProtocol;
|
||||
use pnet::packet::ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::Packet;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::common::error::Error;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct IpFragment {
|
||||
id: u16,
|
||||
offset: u16,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<'a> From<&Ipv4Packet<'a>> for IpFragment {
|
||||
fn from(packet: &Ipv4Packet<'a>) -> Self {
|
||||
let id = packet.get_identification();
|
||||
let offset = packet.get_fragment_offset() * 8;
|
||||
let data = packet.payload().to_vec();
|
||||
IpFragment { id, offset, data }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct IpPacket {
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
total_length: Option<u16>,
|
||||
fragments: Vec<IpFragment>,
|
||||
}
|
||||
|
||||
impl IpPacket {
|
||||
fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
|
||||
IpPacket {
|
||||
source,
|
||||
destination,
|
||||
total_length: None,
|
||||
fragments: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_fragment(&mut self, fragment: IpFragment) {
|
||||
// make sure the fragment doesn't overlap with existing fragments
|
||||
for f in &self.fragments {
|
||||
if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 {
|
||||
return;
|
||||
}
|
||||
if fragment.offset <= f.offset
|
||||
&& f.offset < fragment.offset + fragment.data.len() as u16
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
self.fragments.push(fragment);
|
||||
}
|
||||
|
||||
fn is_complete(&self) -> bool {
|
||||
if self.total_length.is_none() {
|
||||
return false;
|
||||
}
|
||||
let mut total_length = 0;
|
||||
for fragment in &self.fragments {
|
||||
total_length += fragment.data.len() as u16;
|
||||
}
|
||||
tracing::trace!(?total_length, ?self.total_length, "ip resembler checking is_complete");
|
||||
Some(total_length) == self.total_length
|
||||
}
|
||||
|
||||
fn set_total_length(&mut self, total_length: u16) {
|
||||
self.total_length = Some(total_length);
|
||||
}
|
||||
|
||||
fn assemble(&mut self) -> Option<Vec<u8>> {
|
||||
if !self.is_complete() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// sort fragments by offset
|
||||
self.fragments.sort_by_key(|f| f.offset);
|
||||
|
||||
let mut packet = vec![0u8; self.total_length.unwrap() as usize];
|
||||
for fragment in &self.fragments {
|
||||
let start = fragment.offset as usize;
|
||||
let end = start + fragment.data.len();
|
||||
packet[start..end].copy_from_slice(&fragment.data);
|
||||
}
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
|
||||
struct IpResemblerKey {
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
id: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IpResemblerValue {
|
||||
packet: IpPacket,
|
||||
timestamp: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct IpReassembler {
|
||||
packets: DashMap<IpResemblerKey, IpResemblerValue>,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl IpReassembler {
|
||||
pub fn new(timeout: Duration) -> Self {
|
||||
IpReassembler {
|
||||
packets: DashMap::new(),
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_packet_fragmented(packet: &Ipv4Packet) -> bool {
|
||||
packet.get_fragment_offset() != 0 || packet.get_flags() & Ipv4Flags::MoreFragments != 0
|
||||
}
|
||||
|
||||
pub fn is_last_fragment(packet: &Ipv4Packet) -> bool {
|
||||
packet.get_flags() & Ipv4Flags::MoreFragments == 0
|
||||
}
|
||||
|
||||
pub fn add_fragment(
|
||||
&self,
|
||||
source: Ipv4Addr,
|
||||
destination: Ipv4Addr,
|
||||
packet: &Ipv4Packet,
|
||||
) -> Option<Vec<u8>> {
|
||||
let id = packet.get_identification();
|
||||
let total_length = packet.get_total_length() - packet.get_header_length() as u16 * 4;
|
||||
if total_length != packet.payload().len() as u16 {
|
||||
tracing::trace!(
|
||||
?packet,
|
||||
?total_length,
|
||||
payload_len = ?packet.payload().len(),
|
||||
"unexpected total length",
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let fragment: IpFragment = packet.into();
|
||||
let key = IpResemblerKey {
|
||||
source,
|
||||
destination,
|
||||
id,
|
||||
};
|
||||
|
||||
let mut entry = self.packets.entry(key.clone()).or_insert_with(|| {
|
||||
let packet = IpPacket::new(source, destination);
|
||||
let timestamp = Instant::now();
|
||||
IpResemblerValue { packet, timestamp }
|
||||
});
|
||||
let value_mut = entry.value_mut();
|
||||
|
||||
if Self::is_last_fragment(packet) {
|
||||
value_mut
|
||||
.packet
|
||||
.set_total_length(total_length + fragment.offset);
|
||||
}
|
||||
|
||||
value_mut.packet.add_fragment(fragment);
|
||||
if let Some(data) = value_mut.packet.assemble() {
|
||||
drop(entry);
|
||||
self.packets.remove(&key);
|
||||
Some(data)
|
||||
} else {
|
||||
value_mut.timestamp = Instant::now();
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_expired_packets(&self) {
|
||||
let timeout = self.timeout;
|
||||
self.packets.retain(|_, v| v.timestamp.elapsed() <= timeout);
|
||||
}
|
||||
}
|
||||
|
||||
// ip payload should be in buf[20..]
|
||||
pub fn compose_ipv4_packet<F>(
|
||||
buf: &mut [u8],
|
||||
src_v4: &Ipv4Addr,
|
||||
dst_v4: &Ipv4Addr,
|
||||
next_protocol: IpNextHeaderProtocol,
|
||||
payload_len: usize,
|
||||
payload_mtu: usize,
|
||||
ip_id: u16,
|
||||
cb: F,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
F: Fn(&[u8]) -> Result<(), Error>,
|
||||
{
|
||||
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
|
||||
let mut buf_offset = 0;
|
||||
let mut fragment_offset = 0;
|
||||
let mut cur_piece = 0;
|
||||
while fragment_offset < payload_len {
|
||||
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
|
||||
let fragment_len = next_fragment_offset - fragment_offset;
|
||||
let mut ipv4_packet =
|
||||
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap();
|
||||
ipv4_packet.set_version(4);
|
||||
ipv4_packet.set_header_length(5);
|
||||
ipv4_packet.set_total_length((fragment_len + 20) as u16);
|
||||
ipv4_packet.set_identification(ip_id);
|
||||
if total_pieces > 1 {
|
||||
if cur_piece != total_pieces - 1 {
|
||||
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
|
||||
} else {
|
||||
ipv4_packet.set_flags(0);
|
||||
}
|
||||
assert_eq!(0, fragment_offset % 8);
|
||||
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
|
||||
} else {
|
||||
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
|
||||
ipv4_packet.set_fragment_offset(0);
|
||||
}
|
||||
ipv4_packet.set_ecn(0);
|
||||
ipv4_packet.set_dscp(0);
|
||||
ipv4_packet.set_ttl(32);
|
||||
ipv4_packet.set_source(src_v4.clone());
|
||||
ipv4_packet.set_destination(dst_v4.clone());
|
||||
ipv4_packet.set_next_level_protocol(next_protocol);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
||||
|
||||
cb(ipv4_packet.packet())?;
|
||||
|
||||
buf_offset += next_fragment_offset - fragment_offset;
|
||||
fragment_offset = next_fragment_offset;
|
||||
cur_piece += 1;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resembler() {
|
||||
let raw_packets = vec![
|
||||
// last packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// 1st packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x00, 0x02, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// 2nd packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
// expired packet
|
||||
vec![
|
||||
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x47, 0x20, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
|
||||
0x00, 0x01, 0xc0, 0xa8, 0x00, 0x02, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
],
|
||||
];
|
||||
|
||||
let source = "192.168.0.1".parse().unwrap();
|
||||
let destination = "192.168.0.2".parse().unwrap();
|
||||
let resembler = IpReassembler::new(Duration::from_secs(1));
|
||||
|
||||
for (idx, raw_packet) in raw_packets.iter().enumerate() {
|
||||
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
|
||||
let ret = resembler.add_fragment(source, destination, &packet);
|
||||
if idx != 2 {
|
||||
assert!(ret.is_none());
|
||||
} else {
|
||||
assert!(ret.is_some());
|
||||
}
|
||||
println!(
|
||||
"packet: {:?}, ret: {:?}, palyload_len: {}",
|
||||
packet,
|
||||
ret,
|
||||
packet.payload().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
resembler.remove_expired_packets();
|
||||
assert_eq!(1, resembler.packets.len());
|
||||
|
||||
std::thread::sleep(Duration::from_secs(2));
|
||||
resembler.remove_expired_packets();
|
||||
assert_eq!(0, resembler.packets.len());
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ use tokio::task::JoinSet;
|
|||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
|
||||
pub mod icmp_proxy;
|
||||
pub mod ip_reassembler;
|
||||
pub mod tcp_proxy;
|
||||
pub mod udp_proxy;
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ use std::{
|
|||
use dashmap::DashMap;
|
||||
use pnet::packet::{
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
|
||||
ipv4::Ipv4Packet,
|
||||
udp::{self, MutableUdpPacket},
|
||||
Packet,
|
||||
};
|
||||
|
@ -25,6 +25,7 @@ use tracing::Level;
|
|||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
gateway::ip_reassembler::compose_ipv4_packet,
|
||||
peers::{peer_manager::PeerManager, PeerPacketFilter},
|
||||
tunnel::{
|
||||
common::setup_sokcet2,
|
||||
|
@ -32,7 +33,7 @@ use crate::{
|
|||
},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
use super::{ip_reassembler::IpReassembler, CidrSet};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct UdpNatKey {
|
||||
|
@ -105,60 +106,31 @@ impl UdpNatEntry {
|
|||
nat_src_v4.ip(),
|
||||
));
|
||||
|
||||
let payload_len = payload_len + 8; // include udp header
|
||||
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
|
||||
let mut buf_offset = 0;
|
||||
let mut fragment_offset = 0;
|
||||
let mut cur_piece = 0;
|
||||
while fragment_offset < payload_len {
|
||||
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
|
||||
let fragment_len = next_fragment_offset - fragment_offset;
|
||||
let mut ipv4_packet =
|
||||
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
|
||||
.unwrap();
|
||||
ipv4_packet.set_version(4);
|
||||
ipv4_packet.set_header_length(5);
|
||||
ipv4_packet.set_total_length((fragment_len + 20) as u16);
|
||||
ipv4_packet.set_identification(ip_id);
|
||||
if total_pieces > 1 {
|
||||
if cur_piece != total_pieces - 1 {
|
||||
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
|
||||
} else {
|
||||
ipv4_packet.set_flags(0);
|
||||
compose_ipv4_packet(
|
||||
&mut buf[..],
|
||||
src_v4.ip(),
|
||||
nat_src_v4.ip(),
|
||||
IpNextHeaderProtocols::Udp,
|
||||
payload_len + 8, // include udp header
|
||||
payload_mtu,
|
||||
ip_id,
|
||||
|buf| {
|
||||
let mut p = ZCPacket::new_with_payload(buf);
|
||||
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
|
||||
|
||||
if let Err(e) = packet_sender.send(p) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
return Err(Error::AnyhowError(e.into()));
|
||||
}
|
||||
assert_eq!(0, fragment_offset % 8);
|
||||
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
|
||||
} else {
|
||||
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
|
||||
ipv4_packet.set_fragment_offset(0);
|
||||
}
|
||||
ipv4_packet.set_ecn(0);
|
||||
ipv4_packet.set_dscp(0);
|
||||
ipv4_packet.set_ttl(32);
|
||||
ipv4_packet.set_source(src_v4.ip().clone());
|
||||
ipv4_packet.set_destination(nat_src_v4.ip().clone());
|
||||
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
||||
|
||||
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
|
||||
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
|
||||
|
||||
if let Err(e) = packet_sender.send(p) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
return Err(Error::AnyhowError(e.into()));
|
||||
}
|
||||
|
||||
buf_offset += next_fragment_offset - fragment_offset;
|
||||
fragment_offset = next_fragment_offset;
|
||||
cur_piece += 1;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<ZCPacket>) {
|
||||
let mut buf = [0u8; 8192];
|
||||
let mut buf = [0u8; 65536];
|
||||
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
|
||||
let mut ip_id = 1;
|
||||
|
||||
|
@ -223,6 +195,8 @@ pub struct UdpProxy {
|
|||
receiver: Mutex<Option<UnboundedReceiver<ZCPacket>>>,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
|
||||
ip_resemmbler: Arc<IpReassembler>,
|
||||
}
|
||||
|
||||
impl UdpProxy {
|
||||
|
@ -247,7 +221,18 @@ impl UdpProxy {
|
|||
return None;
|
||||
}
|
||||
|
||||
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
|
||||
let resembled_buf: Option<Vec<u8>>;
|
||||
let udp_packet = if IpReassembler::is_packet_fragmented(&ipv4) {
|
||||
resembled_buf =
|
||||
self.ip_resemmbler
|
||||
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
|
||||
if resembled_buf.is_none() {
|
||||
return None;
|
||||
};
|
||||
udp::UdpPacket::new(resembled_buf.as_ref().unwrap())?
|
||||
} else {
|
||||
udp::UdpPacket::new(ipv4.payload())?
|
||||
};
|
||||
|
||||
tracing::trace!(
|
||||
?packet,
|
||||
|
@ -336,6 +321,7 @@ impl UdpProxy {
|
|||
sender,
|
||||
receiver: Mutex::new(Some(receiver)),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
ip_resemmbler: Arc::new(IpReassembler::new(Duration::from_secs(10))),
|
||||
};
|
||||
Ok(Arc::new(ret))
|
||||
}
|
||||
|
@ -362,6 +348,14 @@ impl UdpProxy {
|
|||
}
|
||||
});
|
||||
|
||||
let ip_resembler = self.ip_resemmbler.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
ip_resembler.remove_expired_packets();
|
||||
}
|
||||
});
|
||||
|
||||
// forward packets to peer manager
|
||||
let mut receiver = self.receiver.lock().await.take().unwrap();
|
||||
let peer_manager = self.peer_manager.clone();
|
||||
|
|
|
@ -292,7 +292,10 @@ impl VirtualNic {
|
|||
config.platform(|config| {
|
||||
config.skip_config(true);
|
||||
config.guid(None);
|
||||
config.ring_cap(Some(config.min_ring_cap() * 2));
|
||||
config.ring_cap(Some(std::cmp::min(
|
||||
config.min_ring_cap() * 32,
|
||||
config.max_ring_cap(),
|
||||
)));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ pub async fn init_three_node(proto: &str) -> Vec<Instance> {
|
|||
vec![inst1, inst2, inst3]
|
||||
}
|
||||
|
||||
async fn ping_test(from_netns: &str, target_ip: &str) -> bool {
|
||||
async fn ping_test(from_netns: &str, target_ip: &str, payload_size: Option<usize>) -> bool {
|
||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
||||
let code = tokio::process::Command::new("ip")
|
||||
.args(&[
|
||||
|
@ -146,6 +146,8 @@ async fn ping_test(from_netns: &str, target_ip: &str) -> bool {
|
|||
"ping",
|
||||
"-c",
|
||||
"1",
|
||||
"-s",
|
||||
payload_size.unwrap_or(56).to_string().as_str(),
|
||||
"-W",
|
||||
"1",
|
||||
target_ip.to_string().as_str(),
|
||||
|
@ -175,7 +177,7 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr
|
|||
);
|
||||
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_c", "10.144.144.1").await },
|
||||
|| async { ping_test("net_c", "10.144.144.1", None).await },
|
||||
Duration::from_secs(5000),
|
||||
)
|
||||
.await;
|
||||
|
@ -185,6 +187,8 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr
|
|||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
use rand::Rng;
|
||||
|
||||
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener};
|
||||
|
||||
let mut insts = init_three_node(proto).await;
|
||||
|
@ -210,11 +214,15 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str
|
|||
let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
|
||||
let mut buf = vec![0; 32];
|
||||
rand::thread_rng().fill(&mut buf[..]);
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
buf,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
@ -241,7 +249,13 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st
|
|||
.await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_a", "10.1.2.4").await },
|
||||
|| async { ping_test("net_a", "10.1.2.4", None).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_a", "10.1.2.4", Some(5 * 1024)).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
@ -318,6 +332,8 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
|
|||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
use rand::Rng;
|
||||
|
||||
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener};
|
||||
|
||||
let mut insts = init_three_node(proto).await;
|
||||
|
@ -343,11 +359,32 @@ pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str
|
|||
let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
|
||||
// NOTE: this should not excced udp tunnel max buffer size
|
||||
let mut buf = vec![0; 20 * 1024];
|
||||
rand::thread_rng().fill(&mut buf[..]);
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
buf,
|
||||
)
|
||||
.await;
|
||||
|
||||
// no fragment
|
||||
let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
|
||||
let mut buf = vec![0; 1 * 1024];
|
||||
rand::thread_rng().fill(&mut buf[..]);
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
buf,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
@ -443,7 +480,7 @@ pub async fn foreign_network_forward_nic_data() {
|
|||
.await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_b", "10.144.145.2").await },
|
||||
|| async { ping_test("net_b", "10.144.145.2", None).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
@ -531,19 +568,19 @@ pub async fn wireguard_vpn_portal() {
|
|||
|
||||
// ping other node in network
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_d", "10.144.144.1").await },
|
||||
|| async { ping_test("net_d", "10.144.144.1", None).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_d", "10.144.144.2").await },
|
||||
|| async { ping_test("net_d", "10.144.144.2", None).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
// ping portal node
|
||||
wait_for_condition(
|
||||
|| async { ping_test("net_d", "10.144.144.3").await },
|
||||
|| async { ping_test("net_d", "10.144.144.3", None).await },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
|
|
@ -107,7 +107,10 @@ impl<R> FramedReader<R> {
|
|||
}
|
||||
}
|
||||
|
||||
fn extract_one_packet(buf: &mut BytesMut) -> Option<ZCPacket> {
|
||||
fn extract_one_packet(
|
||||
buf: &mut BytesMut,
|
||||
max_packet_size: usize,
|
||||
) -> Option<Result<ZCPacket, TunnelError>> {
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
|
||||
// header is not complete
|
||||
return None;
|
||||
|
@ -115,6 +118,11 @@ impl<R> FramedReader<R> {
|
|||
|
||||
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
|
||||
let body_len = header.len.get() as usize;
|
||||
if body_len > max_packet_size {
|
||||
// body is too long
|
||||
return Some(Err(TunnelError::InvalidPacket("body too long".to_string())));
|
||||
}
|
||||
|
||||
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
|
||||
// body is not complete
|
||||
return None;
|
||||
|
@ -122,7 +130,7 @@ impl<R> FramedReader<R> {
|
|||
|
||||
// extract one packet
|
||||
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
|
||||
Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))
|
||||
Some(Ok(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,8 +147,10 @@ where
|
|||
let mut self_mut = self.project();
|
||||
|
||||
loop {
|
||||
while let Some(packet) = Self::extract_one_packet(self_mut.buf) {
|
||||
return Poll::Ready(Some(Ok(packet)));
|
||||
while let Some(packet) =
|
||||
Self::extract_one_packet(self_mut.buf, *self_mut.max_packet_size)
|
||||
{
|
||||
return Poll::Ready(Some(packet));
|
||||
}
|
||||
|
||||
reserve_buf(
|
||||
|
@ -465,7 +475,14 @@ pub mod tests {
|
|||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
_tunnel_pingpong_netns(
|
||||
listener,
|
||||
connector,
|
||||
NetNS::new(None),
|
||||
NetNS::new(None),
|
||||
"12345678abcdefg".as_bytes().to_vec(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
|
@ -473,6 +490,7 @@ pub mod tests {
|
|||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
buf: Vec<u8>,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
|
@ -503,7 +521,7 @@ pub mod tests {
|
|||
|
||||
let (mut recv, mut send) = tunnel.split();
|
||||
|
||||
send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes()))
|
||||
send.send(ZCPacket::new_with_payload(buf.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -513,7 +531,7 @@ pub mod tests {
|
|||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
|
||||
assert_eq!(ret.payload(), Bytes::from(buf));
|
||||
|
||||
send.close().await.unwrap();
|
||||
|
||||
|
|
|
@ -158,7 +158,13 @@ where
|
|||
let mut buf = BytesMut::new();
|
||||
loop {
|
||||
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16);
|
||||
let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap();
|
||||
let (dg_size, addr) = match socket.recv_buf_from(&mut buf).await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "udp recv from socket error");
|
||||
break;
|
||||
}
|
||||
};
|
||||
tracing::trace!(
|
||||
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
|
|
Loading…
Reference in New Issue
Block a user