Fix udp and win route (#16)

* robust udp tunnel
* fix windows route add
* use pnet to get index
* windows disable udp reset
This commit is contained in:
Sijie.Sun 2024-02-08 16:27:18 +08:00 committed by GitHub
parent 2c2e41be24
commit 7fc4aecdb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 285 additions and 11 deletions

View File

@ -87,6 +87,9 @@ public-ip = { version = "0.2", features = ["default"] }
clap = { version = "4.4", features = ["derive"] }
[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.52", features = ["Win32_Networking_WinSock", "Win32_NetworkManagement_IpHelper", "Win32_Foundation", "Win32_System_IO"] }
[build-dependencies]
tonic-build = "0.10"

View File

@ -0,0 +1,2 @@
#[cfg(target_os = "windows")]
pub mod windows;

View File

@ -0,0 +1,142 @@
use std::{
ffi::c_void,
io::{self, ErrorKind},
mem,
net::SocketAddr,
os::windows::io::AsRawSocket,
ptr,
};
use windows_sys::{
core::PCSTR,
Win32::{
Foundation::{BOOL, FALSE},
Networking::WinSock::{
htonl, setsockopt, WSAGetLastError, WSAIoctl, IPPROTO_IP, IPPROTO_IPV6,
IPV6_UNICAST_IF, IP_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET, SOCKET_ERROR,
},
},
};
use crate::tunnels::common::get_interface_name_by_ip;
pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> {
let handle = socket.as_raw_socket() as SOCKET;
unsafe {
// Ignoring UdpSocket's WSAECONNRESET error
// https://github.com/shadowsocks/shadowsocks-rust/issues/179
// https://stackoverflow.com/questions/30749423/is-winsock-error-10054-wsaeconnreset-normal-with-udp-to-from-localhost
//
// This is because `UdpSocket::recv_from` may return WSAECONNRESET
// if you called `UdpSocket::send_to` a destination that is not existed (may be closed).
//
// It is not an error. Could be ignored completely.
// We have to ignore it here because it will crash the server.
let mut bytes_returned: u32 = 0;
let enable: BOOL = FALSE;
let ret = WSAIoctl(
handle,
SIO_UDP_CONNRESET,
&enable as *const _ as *const c_void,
mem::size_of_val(&enable) as u32,
ptr::null_mut(),
0,
&mut bytes_returned as *mut _,
ptr::null_mut(),
None,
);
if ret == SOCKET_ERROR {
use std::io::Error;
// Error occurs
let err_code = WSAGetLastError();
return Err(Error::from_raw_os_error(err_code));
}
}
Ok(())
}
pub fn find_interface_index_cached(iface_name: &str) -> io::Result<u32> {
let ifaces = pnet::datalink::interfaces();
for iface in ifaces {
if iface.name == iface_name {
return Ok(iface.index);
}
}
let err = io::Error::new(
ErrorKind::NotFound,
format!("Failed to find interface index for {}", iface_name),
);
Err(err)
}
pub fn set_ip_unicast_if<S: AsRawSocket>(
socket: &S,
addr: &SocketAddr,
iface: &str,
) -> io::Result<()> {
let handle = socket.as_raw_socket() as SOCKET;
let if_index = find_interface_index_cached(iface)?;
unsafe {
// https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
let ret = match addr {
SocketAddr::V4(..) => {
// Interface index is in network byte order for IPPROTO_IP.
let if_index = htonl(if_index);
setsockopt(
handle,
IPPROTO_IP as i32,
IP_UNICAST_IF as i32,
&if_index as *const _ as PCSTR,
mem::size_of_val(&if_index) as i32,
)
}
SocketAddr::V6(..) => {
// Interface index is in host byte order for IPPROTO_IPV6.
setsockopt(
handle,
IPPROTO_IPV6 as i32,
IPV6_UNICAST_IF as i32,
&if_index as *const _ as PCSTR,
mem::size_of_val(&if_index) as i32,
)
}
};
if ret == SOCKET_ERROR {
let err = io::Error::from_raw_os_error(WSAGetLastError());
tracing::error!(
"set IP_UNICAST_IF / IPV6_UNICAST_IF interface: {}, index: {}, error: {}",
iface,
if_index,
err
);
return Err(err);
}
}
Ok(())
}
pub fn setup_socket_for_win<S: AsRawSocket>(
socket: &S,
bind_addr: &SocketAddr,
is_udp: bool,
) -> io::Result<()> {
if is_udp {
disable_connection_reset(socket)?;
}
if let Some(iface) = get_interface_name_by_ip(&bind_addr.ip()) {
set_ip_unicast_if(socket, bind_addr, iface.as_str())?;
}
Ok(())
}

View File

@ -196,8 +196,17 @@ impl IfConfiguerTrait for LinuxIfConfiger {
}
}
#[cfg(target_os = "windows")]
pub struct WindowsIfConfiger {}
#[cfg(target_os = "windows")]
impl WindowsIfConfiger {
pub fn get_interface_index(name: &str) -> Option<u32> {
crate::arch::windows::find_interface_index_cached(name).ok()
}
}
#[cfg(target_os = "windows")]
#[async_trait]
impl IfConfiguerTrait for WindowsIfConfiger {
async fn add_ipv4_route(
@ -206,12 +215,15 @@ impl IfConfiguerTrait for WindowsIfConfiger {
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
let Some(idx) = Self::get_interface_index(name) else {
return Err(Error::NotFound);
};
run_shell_cmd(
format!(
"route add {} mask {} {}",
"route ADD {} MASK {} 10.1.1.1 IF {} METRIC 255",
address,
cidr_to_subnet_mask(cidr_prefix),
name
idx
)
.as_str(),
)
@ -224,12 +236,15 @@ impl IfConfiguerTrait for WindowsIfConfiger {
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
let Some(idx) = Self::get_interface_index(name) else {
return Err(Error::NotFound);
};
run_shell_cmd(
format!(
"route delete {} mask {} {}",
"route DELETE {} MASK {} IF {}",
address,
cidr_to_subnet_mask(cidr_prefix),
name
idx
)
.as_str(),
)

View File

@ -14,6 +14,7 @@ use crate::{
},
peers::{peer_manager::PeerManager, PeerId},
tunnels::{
common::setup_sokcet2,
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener,
},
@ -387,9 +388,14 @@ impl UdpHolePunchConnector {
.unwrap(),
);
let socket = UdpSocket::bind(local_socket_addr)
.await
.with_context(|| "")?;
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(local_socket_addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
setup_sokcet2(&socket2_socket, &local_socket_addr)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
Ok(connector
.try_connect_with_socket(socket)
.await

View File

@ -5,6 +5,7 @@ mod tests;
use clap::Parser;
mod arch;
mod common;
mod connector;
mod gateway;

View File

@ -273,6 +273,12 @@ pub(crate) fn setup_sokcet2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
) -> Result<(), TunnelError> {
#[cfg(target_os = "windows")]
{
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, is_udp)?;
}
socket2_socket.set_nonblocking(true)?;
socket2_socket.set_reuse_address(true)?;
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;

View File

@ -121,7 +121,11 @@ fn get_tunnel_from_socket(
}
let (buf, addr) = v.unwrap();
assert_eq!(addr, recv_addr.clone());
if recv_addr != addr {
tracing::warn!(?addr, ?recv_addr, "udp recv addr not match");
return None;
}
Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
});
let stream = Box::pin(stream);
@ -304,7 +308,7 @@ impl TunnelListener for UdpTunnelListener {
};
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
let conn = Self::handle_connect(
let Ok(conn) = Self::handle_connect(
socket.clone(),
addr,
forward_tasks.clone(),
@ -313,7 +317,10 @@ impl TunnelListener for UdpTunnelListener {
udp_packet.conn_id.into(),
)
.await
.unwrap();
else {
tracing::error!(?addr, "udp handle connect error");
continue;
};
if let Err(e) = conn_send.send(conn).await {
tracing::warn!(?e, "udp send conn to accept channel error");
}
@ -465,6 +472,9 @@ impl UdpTunnelConnector {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
log::warn!("udp connect: {:?}", self.addr);
#[cfg(target_os = "windows")]
crate::arch::windows::disable_connection_reset(&socket)?;
// send syn
let conn_id = rand::random();
let udp_packet = UdpPacket::new_syn_packet(conn_id);
@ -544,7 +554,12 @@ impl super::TunnelConnector for UdpTunnelConnector {
#[cfg(test)]
mod tests {
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use std::time::Duration;
use rand::Rng;
use tokio::time::timeout;
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong};
use super::*;
@ -578,4 +593,88 @@ mod tests {
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
async fn send_random_data_to_socket(remote_url: url::Url) {
let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
socket
.connect(format!(
"{}:{}",
remote_url.host().unwrap(),
remote_url.port().unwrap()
))
.await
.unwrap();
// get a random 100-len buf
loop {
let mut buf = vec![0u8; 100];
rand::thread_rng().fill(&mut buf[..]);
socket.send(&buf).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
}
#[tokio::test]
async fn udp_multiple_conns() {
let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
listener.listen().await.unwrap();
let _lis = tokio::spawn(async move {
loop {
let ret = listener.accept().await.unwrap();
assert_eq!(
ret.info().unwrap().local_addr,
listener.local_url().to_string()
);
tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
}
});
let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
let t1 = connector1.connect().await.unwrap();
let t2 = connector2.connect().await.unwrap();
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
));
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
));
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
));
let sender1 = tokio::spawn(async move {
let mut sink = t1.pin_sink();
let mut stream = t1.pin_stream();
for i in 0..10 {
sink.send(Bytes::from("hello1")).await.unwrap();
let recv = stream.next().await.unwrap().unwrap();
println!("t1 recv: {:?}, {:?}", recv, i);
assert_eq!(recv, Bytes::from("hello1"));
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
});
let sender2 = tokio::spawn(async move {
let mut sink = t2.pin_sink();
let mut stream = t2.pin_stream();
for i in 0..10 {
sink.send(Bytes::from("hello2")).await.unwrap();
let recv = stream.next().await.unwrap().unwrap();
println!("t2 recv: {:?}, {:?}", recv, i);
assert_eq!(recv, Bytes::from("hello2"));
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
});
let _ = tokio::join!(sender1, sender2);
}
}