minor fixed (#120)

1. fix mtu, always set by ourselves and use smaller value
2. wireguard connector should return tunnel after receive packet
This commit is contained in:
Sijie.Sun 2024-05-18 18:04:06 +08:00 committed by GitHub
parent 0ead308392
commit 6efbb5cb3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 58 deletions

View File

@ -151,7 +151,7 @@ pub struct Flags {
pub enable_encryption: bool, pub enable_encryption: bool,
#[derivative(Default(value = "true"))] #[derivative(Default(value = "true"))]
pub enable_ipv6: bool, pub enable_ipv6: bool,
#[derivative(Default(value = "1420"))] #[derivative(Default(value = "1380"))]
pub mtu: u16, pub mtu: u16,
#[derivative(Default(value = "true"))] #[derivative(Default(value = "true"))]
pub latency_first: bool, pub latency_first: bool,

View File

@ -30,6 +30,7 @@ pub trait IfConfiguerTrait: Send + Sync {
async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> { async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
return Ok(()); return Ok(());
} }
async fn set_mtu(&self, _name: &str, _mtu: u32) -> Result<(), Error>;
} }
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr { fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
@ -77,9 +78,7 @@ async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd"); tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
if !cmd_out.status.success() { if !cmd_out.status.success() {
return Err(Error::ShellCommandError( return Err(Error::ShellCommandError(stdout + &stderr));
stdout + &stderr,
));
} }
Ok(()) Ok(())
} }
@ -154,6 +153,10 @@ impl IfConfiguerTrait for MacIfConfiger {
.await .await
} }
} }
async fn set_mtu(&self, name: &str, mtu: u32) -> Result<(), Error> {
run_shell_cmd(format!("ifconfig {} mtu {}", name, mtu).as_str()).await
}
} }
pub struct LinuxIfConfiger {} pub struct LinuxIfConfiger {}
@ -210,6 +213,10 @@ impl IfConfiguerTrait for LinuxIfConfiger {
.await .await
} }
} }
async fn set_mtu(&self, name: &str, mtu: u32) -> Result<(), Error> {
run_shell_cmd(format!("ip link set dev {} mtu {}", name, mtu).as_str()).await
}
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
@ -362,6 +369,13 @@ impl IfConfiguerTrait for WindowsIfConfiger {
.await??, .await??,
) )
} }
async fn set_mtu(&self, name: &str, mtu: u32) -> Result<(), Error> {
run_shell_cmd(
format!("netsh interface ipv4 set subinterface {} mtu={}", name, mtu).as_str(),
)
.await
}
} }
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]

View File

@ -295,14 +295,6 @@ impl VirtualNic {
todo!("queue_num != 1") todo!("queue_num != 1")
} }
config.queues(self.queue_num); config.queues(self.queue_num);
let flags = self.global_ctx.config.get_flags();
let mut mtu_in_config = flags.mtu;
if flags.enable_encryption {
mtu_in_config -= 20;
}
config.mtu(mtu_in_config as i32);
config.up(); config.up();
let dev = { let dev = {
@ -313,6 +305,19 @@ impl VirtualNic {
let ifname = dev.get_ref().name()?; let ifname = dev.get_ref().name()?;
self.ifcfg.wait_interface_show(ifname.as_str()).await?; self.ifcfg.wait_interface_show(ifname.as_str()).await?;
let flags = self.global_ctx.config.get_flags();
let mut mtu_in_config = flags.mtu;
if flags.enable_encryption {
mtu_in_config -= 20;
}
{
// set mtu by ourselves, rust-tun does not handle it correctly on windows
let _g = self.global_ctx.net_ns.guard();
self.ifcfg
.set_mtu(ifname.as_str(), mtu_in_config as u32)
.await?;
}
let (a, b) = BiLock::new(dev); let (a, b) = BiLock::new(dev);
let ft = TunnelWrapper::new( let ft = TunnelWrapper::new(

View File

@ -38,7 +38,7 @@ use super::{
IpVersion, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream, IpVersion, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
}; };
const MAX_PACKET: usize = 65500; const MAX_PACKET: usize = 2048;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum WgType { enum WgType {
@ -335,6 +335,7 @@ impl WgPeerData {
} }
struct WgPeer { struct WgPeer {
tunn: Option<Mutex<Tunn>>,
udp: Arc<UdpSocket>, // only for send udp: Arc<UdpSocket>, // only for send
config: WgConfig, config: WgConfig,
endpoint: SocketAddr, endpoint: SocketAddr,
@ -350,10 +351,18 @@ struct WgPeer {
impl WgPeer { impl WgPeer {
fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self { fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
WgPeer { WgPeer {
tunn: Some(Mutex::new(Tunn::new(
config.my_secret_key.clone(),
config.peer_public_key.clone(),
None,
None,
rand::thread_rng().next_u32(),
None,
))),
udp, udp,
config, config,
endpoint, endpoint,
sink: std::sync::Mutex::new(None), sink: std::sync::Mutex::new(None),
data: None, data: None,
@ -392,14 +401,7 @@ impl WgPeer {
let data = WgPeerData { let data = WgPeerData {
udp: self.udp.clone(), udp: self.udp.clone(),
endpoint: self.endpoint, endpoint: self.endpoint,
tunn: Arc::new(Mutex::new(Tunn::new( tunn: Arc::new(self.tunn.take().unwrap()),
self.config.my_secret_key.clone(),
self.config.peer_public_key.clone(),
None,
None,
rand::thread_rng().next_u32(),
None,
))),
wg_type: self.config.wg_type.clone(), wg_type: self.config.wg_type.clone(),
stopped: Arc::new(AtomicBool::new(false)), stopped: Arc::new(AtomicBool::new(false)),
}; };
@ -421,6 +423,29 @@ impl WgPeer {
.stopped .stopped
.load(std::sync::atomic::Ordering::Relaxed) .load(std::sync::atomic::Ordering::Relaxed)
} }
async fn create_handshake_init(&self) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_init = self
.tunn
.as_ref()
.unwrap()
.lock()
.await
.format_handshake_initiation(&mut dst, false);
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
sent
} else {
unreachable!();
};
handshake_init.into()
}
fn udp_socket(&self) -> Arc<UdpSocket> {
self.udp.clone()
}
} }
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>; type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
@ -592,37 +617,6 @@ impl WgTunnelConnector {
} }
} }
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
sent
} else {
unreachable!();
};
handshake_init.into()
}
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
assert!(
matches!(keepalive, TunnResult::WriteToNetwork(_)),
"Failed to parse handshake response, {:?}",
keepalive
);
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
sent
} else {
unreachable!();
};
keepalive.into()
}
#[tracing::instrument(skip(config))] #[tracing::instrument(skip(config))]
async fn connect_with_socket( async fn connect_with_socket(
addr_url: url::Url, addr_url: url::Url,
@ -634,17 +628,25 @@ impl WgTunnelConnector {
let local_addr = udp.local_addr().unwrap().to_string(); let local_addr = udp.local_addr().unwrap().to_string();
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr); let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
let tunnel = wg_peer.start_and_get_tunnel(); let udp = wg_peer.udp_socket();
// do handshake here so we will return after receive first packet
let handshake = wg_peer.create_handshake_init().await;
udp.send_to(&handshake, addr).await?;
let mut buf = [0u8; MAX_PACKET];
let (n, recv_addr) = udp.recv_from(&mut buf).await.unwrap();
if recv_addr != addr {
tracing::warn!(?recv_addr, "Received packet from changed address");
}
let tunnel = wg_peer.start_and_get_tunnel();
let data = wg_peer.data.as_ref().unwrap().clone(); let data = wg_peer.data.as_ref().unwrap().clone();
let mut sink = wg_peer.sink.lock().unwrap().take().unwrap(); let mut sink = wg_peer.sink.lock().unwrap().take().unwrap();
wg_peer.tasks.spawn(async move { wg_peer.tasks.spawn(async move {
data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await;
loop { loop {
let mut buf = vec![0u8; MAX_PACKET]; let mut buf = vec![0u8; MAX_PACKET];
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); let (n, _) = data.udp.recv_from(&mut buf).await.unwrap();
if recv_addr != addr {
tracing::warn!(?recv_addr, "Received packet from changed address");
}
data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await; data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await;
} }
}); });