allow specify bind dev for tunnels. also fix bugs #46)

1. fix wireguard / udp tunnel stack overflow on win.
2. custom panic handler to save panic stack.
3. fix iface filter on windows and linux.
4. add scheme black list to direct connector
This commit is contained in:
Sijie.Sun 2024-04-03 21:46:52 +08:00 committed by GitHub
parent 25a7603990
commit e4be86cf92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 320 additions and 104 deletions

View File

@ -86,6 +86,7 @@ anyhow = "1.0"
tarpc = { version = "0.32", features = ["tokio1", "serde1"] }
url = { version = "2.5", features = ["serde"] }
percent-encoding = "2.3.1"
# for tun packet
byteorder = "1.5.0"
@ -144,8 +145,8 @@ serial_test = "3.0.0"
rstest = "0.18.2"
[profile.dev]
panic = "abort"
panic = "unwind"
[profile.release]
panic = "abort"
panic = "unwind"
lto = true

View File

@ -19,8 +19,6 @@ use windows_sys::{
},
};
use crate::tunnels::common::get_interface_name_by_ip;
pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> {
let handle = socket.as_raw_socket() as SOCKET;
@ -132,13 +130,14 @@ pub fn set_ip_unicast_if<S: AsRawSocket>(
pub fn setup_socket_for_win<S: AsRawSocket>(
socket: &S,
bind_addr: &SocketAddr,
bind_dev: Option<String>,
is_udp: bool,
) -> io::Result<()> {
if is_udp {
disable_connection_reset(socket)?;
}
if let Some(iface) = get_interface_name_by_ip(&bind_addr.ip()) {
if let Some(iface) = bind_dev {
set_ip_unicast_if(socket, bind_addr, iface.as_str())?;
}

View File

@ -1,7 +1,3 @@
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 60;
pub const DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC: u64 = 60;
macro_rules! define_global_var {
($name:ident, $type:ty, $init:expr) => {
pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =

View File

@ -7,7 +7,9 @@ use tokio::{
task::JoinSet,
};
use super::{constants::DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, netns::NetNS};
use super::netns::NetNS;
pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60;
struct InterfaceFilter {
iface: NetworkInterface,
@ -15,33 +17,37 @@ struct InterfaceFilter {
#[cfg(target_os = "linux")]
impl InterfaceFilter {
async fn is_iface_bridge(&self) -> bool {
let path = format!("/sys/class/net/{}/bridge", self.iface.name);
async fn is_tun_tap_device(&self) -> bool {
let path = format!("/sys/class/net/{}/tun_flags", self.iface.name);
tokio::fs::metadata(&path).await.is_ok()
}
async fn is_iface_phsical(&self) -> bool {
let path = format!("/sys/class/net/{}/device", self.iface.name);
tokio::fs::metadata(&path).await.is_ok()
async fn has_valid_ip(&self) -> bool {
self.iface
.ips
.iter()
.map(|ip| ip.ip())
.any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
}
async fn filter_iface(&self) -> bool {
tracing::trace!(
"filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_bridge: {}, is_physical: {}",
"filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_tun: {}, has_valid_ip: {}",
self.iface,
self.iface.is_point_to_point(),
self.iface.is_loopback(),
self.iface.is_up(),
self.iface.is_lower_up(),
self.is_iface_bridge().await,
self.is_iface_phsical().await,
self.is_tun_tap_device().await,
self.has_valid_ip().await
);
!self.iface.is_point_to_point()
&& !self.iface.is_loopback()
&& self.iface.is_up()
&& self.iface.is_lower_up()
&& (self.is_iface_bridge().await || self.is_iface_phsical().await)
&& !self.is_tun_tap_device().await
&& self.has_valid_ip().await
}
}
@ -85,7 +91,22 @@ impl InterfaceFilter {
#[cfg(target_os = "windows")]
impl InterfaceFilter {
async fn filter_iface(&self) -> bool {
!self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up()
tracing::debug!(
"iface_name: {:?}, p2p: {:?}, is_up: {:?}, iface: {:?}",
self.iface.name,
self.iface.is_point_to_point(),
self.iface.is_up(),
self.iface
);
!self.iface.is_point_to_point()
&& !self.iface.is_loopback()
&& self
.iface
.ips
.iter()
.map(|ip| ip.ip())
.any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
&& self.iface.mac.map(|mac| !mac.is_zero()).unwrap_or(false)
}
}
@ -143,9 +164,7 @@ impl IPCollector {
loop {
let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
*cached_ip_list.write().await = ip_addrs;
tokio::time::sleep(std::time::Duration::from_secs(
DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC,
))
tokio::time::sleep(std::time::Duration::from_secs(CACHED_IP_LIST_TIMEOUT_SEC))
.await;
}
});
@ -154,6 +173,25 @@ impl IPCollector {
return self.cached_ip_list.read().await.deref().clone();
}
pub async fn collect_interfaces(net_ns: NetNS) -> Vec<NetworkInterface> {
let _g = net_ns.guard();
let ifaces = pnet::datalink::interfaces();
let mut ret = vec![];
for iface in ifaces {
let f = InterfaceFilter {
iface: iface.clone(),
};
if !f.filter_iface().await {
continue;
}
ret.push(iface);
}
ret
}
#[tracing::instrument(skip(net_ns))]
async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
let mut ret = crate::rpc::peer::GetIpListResponse::new();
@ -170,17 +208,9 @@ impl IPCollector {
}
}
let ifaces = Self::collect_interfaces(net_ns.clone()).await;
let _g = net_ns.guard();
let ifaces = pnet::datalink::interfaces();
for iface in ifaces {
let f = InterfaceFilter {
iface: iface.clone(),
};
if !f.filter_iface().await {
continue;
}
for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip();
if ip.is_loopback() || ip.is_multicast() {

View File

@ -3,12 +3,7 @@
use std::sync::Arc;
use crate::{
common::{
constants::{self, DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC},
error::Error,
global_ctx::ArcGlobalCtx,
PeerId,
},
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
};
@ -18,6 +13,9 @@ use tracing::Instrument;
use super::create_connector_by_url;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
#[tarpc::service]
pub trait DirectConnectorRpc {
async fn get_ip_list() -> GetIpListResponse;
@ -76,10 +74,25 @@ impl DirectConnectorManagerRpcServer {
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstBlackListItem(PeerId, String);
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstSchemeBlackListItem(PeerId, String);
struct DirectConnectorManagerData {
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
dst_sceme_blacklist: timedmap::TimedMap<DstSchemeBlackListItem, ()>,
}
impl DirectConnectorManagerData {
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
Self {
global_ctx,
peer_manager,
dst_blacklist: timedmap::TimedMap::new(),
dst_sceme_blacklist: timedmap::TimedMap::new(),
}
}
}
impl std::fmt::Debug for DirectConnectorManagerData {
@ -101,11 +114,7 @@ impl DirectConnectorManager {
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
Self {
global_ctx: global_ctx.clone(),
data: Arc::new(DirectConnectorManagerData {
global_ctx,
peer_manager,
dst_blacklist: timedmap::TimedMap::new(),
}),
data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)),
tasks: JoinSet::new(),
}
}
@ -117,7 +126,7 @@ impl DirectConnectorManager {
pub fn run_as_server(&mut self) {
self.data.peer_manager.get_peer_rpc_mgr().run_service(
constants::DIRECT_CONNECTOR_SERVICE_ID,
DIRECT_CONNECTOR_SERVICE_ID,
DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
);
}
@ -193,7 +202,7 @@ impl DirectConnectorManager {
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
addr: String,
) {
) -> Result<(), Error> {
let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
if let Err(e) = ret {
if !matches!(e, Error::UrlInBlacklist) {
@ -208,47 +217,36 @@ impl DirectConnectorManager {
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
);
}
return Err(e);
} else {
log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
}
}
#[tracing::instrument]
async fn do_try_direct_connect(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let peer_manager = data.peer_manager.clone();
// check if we have direct connection with dst_peer_id
if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
if !c.is_empty() {
return Ok(());
}
}
log::trace!("try direct connect to peer: {}", dst_peer_id);
let ip_list = peer_manager
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
let client =
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
let ip_list = client.get_ip_list(tarpc::context::current()).await;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
ip_list
})
.await?;
#[tracing::instrument]
async fn do_try_direct_connect_internal(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
ip_list: GetIpListResponse,
) -> Result<(), Error> {
let available_listeners = ip_list
.listeners
.iter()
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
.filter(|l| l.port().is_some())
.filter(|l| {
!data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem(
dst_peer_id.clone(),
l.scheme().to_string(),
))
})
.collect::<Vec<_>>();
let mut listener = available_listeners
.get(0)
.ok_or(anyhow::anyhow!("peer {} have no listener", dst_peer_id))?;
let mut listener = available_listeners.get(0).ok_or(anyhow::anyhow!(
"peer {} have no valid listener",
dst_peer_id
))?;
// if have default listener, use it first
listener = available_listeners
@ -283,30 +281,77 @@ impl DirectConnectorManager {
addr,
));
let mut has_succ = false;
while let Some(ret) = tasks.join_next().await {
if let Err(e) = ret {
log::error!("join direct connect task failed: {:?}", e);
} else if let Ok(Ok(_)) = ret {
has_succ = true;
}
}
if !has_succ {
data.dst_sceme_blacklist.insert(
DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()),
(),
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
);
}
Ok(())
}
#[tracing::instrument]
async fn do_try_direct_connect(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let peer_manager = data.peer_manager.clone();
// check if we have direct connection with dst_peer_id
if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
if !c.is_empty() {
return Ok(());
}
}
log::trace!("try direct connect to peer: {}", dst_peer_id);
let ip_list = peer_manager
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
let client =
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
let ip_list = client.get_ip_list(tarpc::context::current()).await;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
ip_list
})
.await?;
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
connector::direct::DirectConnectorManager,
connector::direct::{
DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem,
DstSchemeBlackListItem,
},
instance::listeners::ListenerManager,
peers::tests::{
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
wait_route_appear_with_cost,
},
tunnels::tcp_tunnel::TcpTunnelListener,
rpc::peer::GetIpListResponse,
};
#[rstest::rstest]
#[tokio::test]
async fn direct_connector_basic_test() {
async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) {
let p_a = create_mock_peer_manager().await;
let p_b = create_mock_peer_manager().await;
let p_c = create_mock_peer_manager().await;
@ -321,14 +366,14 @@ mod tests {
dm_a.run_as_client();
dm_c.run_as_server();
let port = if proto == "wg" { 11040 } else { 11041 };
p_c.get_global_ctx()
.config
.set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port)
.parse()
.unwrap()]);
let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone());
lis_c
.add_listener(TcpTunnelListener::new(
"tcp://0.0.0.0:11040".parse().unwrap(),
))
.await
.unwrap();
lis_c.prepare_listeners().await.unwrap();
lis_c.run().await.unwrap();
@ -336,4 +381,31 @@ mod tests {
.await
.unwrap();
}
#[tokio::test]
async fn direct_connector_scheme_blacklist() {
let p_a = create_mock_peer_manager().await;
let data = Arc::new(DirectConnectorManagerData::new(
p_a.get_global_ctx(),
p_a.clone(),
));
let mut ip_list = GetIpListResponse::new();
ip_list
.listeners
.push("tcp://127.0.0.1:10222".parse().unwrap());
ip_list.interface_ipv4s.push("127.0.0.1".to_string());
DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
.await
.unwrap();
assert!(data
.dst_sceme_blacklist
.contains(&DstSchemeBlackListItem(1, "tcp".into())));
assert!(data
.dst_blacklist
.contains(&DstBlackListItem(1, ip_list.listeners[0].to_string())));
}
}

View File

@ -3,7 +3,7 @@
#[cfg(test)]
mod tests;
use std::net::SocketAddr;
use std::{backtrace, io::Write as _, net::SocketAddr};
use anyhow::Context;
use clap::Parser;
@ -318,9 +318,21 @@ fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String {
)
}
fn setup_panic_handler() {
std::panic::set_hook(Box::new(|info| {
let backtrace = backtrace::Backtrace::force_capture();
println!("panic occurred: {:?}", info);
let _ = std::fs::File::create("easytier-panic.log")
.and_then(|mut f| f.write_all(format!("{:?}\n{:#?}", info, backtrace).as_bytes()));
std::process::exit(1);
}));
}
#[tokio::main(flavor = "current_thread")]
#[tracing::instrument]
pub async fn main() {
setup_panic_handler();
let cli = Cli::parse();
tracing::info!(cli = ?cli, "cli args parsed");

View File

@ -25,7 +25,7 @@ use crate::{
PeerId,
},
define_tunnel_filter_chain,
peers::packet::{ArchivedPacketType, CtrlPacketPayload},
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
rpc::{PeerConnInfo, PeerConnStats},
tunnels::{
stats::{Throughput, WindowLatency},
@ -52,6 +52,12 @@ macro_rules! wait_response {
let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec);
if rsp_bytes.packet_type != PacketType::HandShake {
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
return Err(TunnelError::WaitRespError(
"unexpected packet type".to_owned(),
));
}
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
match &resp_payload {
$pattern => $out_var = $value,

View File

@ -275,14 +275,15 @@ pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
None
}
pub(crate) fn setup_sokcet2(
pub(crate) fn setup_sokcet2_ext(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
bind_dev: Option<String>,
) -> Result<(), TunnelError> {
#[cfg(target_os = "windows")]
{
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, is_udp)?;
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
}
socket2_socket.set_nonblocking(true)?;
@ -299,7 +300,7 @@ pub(crate) fn setup_sokcet2(
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
// win can handle this with bind correctly
#[cfg(any(target_os = "ios", target_os = "macos"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
if let Some(dev_name) = bind_dev {
// use IP_BOUND_IF to bind device
unsafe {
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
@ -310,7 +311,7 @@ pub(crate) fn setup_sokcet2(
}
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
if let Some(dev_name) = bind_dev {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
}
@ -318,6 +319,17 @@ pub(crate) fn setup_sokcet2(
Ok(())
}
pub(crate) fn setup_sokcet2(
socket2_socket: &socket2::Socket,
bind_addr: &SocketAddr,
) -> Result<(), TunnelError> {
setup_sokcet2_ext(
socket2_socket,
bind_addr,
super::common::get_interface_name_by_ip(&bind_addr.ip()),
)
}
pub mod tests {
use std::time::Instant;

View File

@ -158,3 +158,35 @@ impl FromUrl for uuid::Uuid {
Ok(o)
}
}
pub struct TunnelUrl {
inner: url::Url,
}
impl From<url::Url> for TunnelUrl {
fn from(url: url::Url) -> Self {
TunnelUrl { inner: url }
}
}
impl From<TunnelUrl> for url::Url {
fn from(url: TunnelUrl) -> Self {
url.into_inner()
}
}
impl TunnelUrl {
pub fn into_inner(self) -> url::Url {
self.inner
}
pub fn bind_dev(&self) -> Option<String> {
self.inner.path().strip_prefix("/").and_then(|s| {
if s.is_empty() {
None
} else {
Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap())
}
})
}
}

View File

@ -23,9 +23,9 @@ use crate::{
use super::{
codec::BytesCodec,
common::{setup_sokcet2, FramedTunnel, TunnelWithCustomInfo},
common::{setup_sokcet2, setup_sokcet2_ext, FramedTunnel, TunnelWithCustomInfo},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelListener,
DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl,
};
pub const UDP_DATA_MTU: usize = 65000;
@ -323,7 +323,14 @@ impl TunnelListener for UdpTunnelListener {
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
} else {
setup_sokcet2(&socket2_socket, &addr)?;
}
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
let socket = self.socket.as_ref().unwrap().clone();
@ -335,7 +342,7 @@ impl TunnelListener for UdpTunnelListener {
async move {
loop {
let mut buf = BytesMut::new();
buf.resize(2500, 0);
buf.resize(UDP_DATA_MTU, 0);
let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
let _ = buf.split_off(_size);
log::trace!(
@ -597,7 +604,16 @@ mod tests {
use rand::Rng;
use tokio::time::timeout;
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong};
use crate::{
common::global_ctx::tests::get_mock_global_ctx,
tunnels::{
check_scheme_and_get_socket_addr,
common::{
get_interface_name_by_ip, setup_sokcet2_ext,
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
},
},
};
use super::*;
@ -723,4 +739,34 @@ mod tests {
let a_udp_packet = rkyv_util::decode_from_bytes::<UdpPacket>(&b).unwrap();
println!("{:?}, {:?}", udp_packet, a_udp_packet);
}
#[tokio::test]
async fn bind_multi_ip_to_same_dev() {
let global_ctx = get_mock_global_ctx();
let ips = global_ctx
.get_ip_collector()
.collect_ip_addrs()
.await
.interface_ipv4s;
if ips.is_empty() {
return;
}
let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
for ip in ips {
println!("bind to ip: {:?}, {:?}", ip, bind_dev);
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
&format!("udp://{}:11111", ip).parse().unwrap(),
"udp",
)
.unwrap();
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)
.unwrap();
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
}
}
}

View File

@ -26,8 +26,10 @@ use crate::{
};
use super::{
check_scheme_and_get_socket_addr, common::setup_sokcet2, ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener,
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl,
};
const MAX_PACKET: usize = 65500;
@ -132,7 +134,7 @@ impl Debug for WgPeerData {
impl WgPeerData {
#[tracing::instrument]
async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> {
let mut send_buf = [0u8; MAX_PACKET];
let mut send_buf = vec![0u8; MAX_PACKET];
let encapsulate_result = {
let mut peer = self.tunn.lock().await;
@ -180,7 +182,7 @@ impl WgPeerData {
/// decapsulates them, and dispatches newly received IP packets.
#[tracing::instrument]
pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) {
let mut send_buf = [0u8; MAX_PACKET];
let mut send_buf = vec![0u8; MAX_PACKET];
let data = &recv_buf[..];
let decapsulate_result = {
let mut peer = self.tunn.lock().await;
@ -200,7 +202,7 @@ impl WgPeerData {
};
let mut peer = self.tunn.lock().await;
loop {
let mut send_buf = [0u8; MAX_PACKET];
let mut send_buf = vec![0u8; MAX_PACKET];
match peer.decapsulate(None, &[], &mut send_buf) {
TunnResult::WriteToNetwork(packet) => {
match self.udp.send_to(packet, self.endpoint).await {
@ -288,10 +290,11 @@ impl WgPeerData {
}
TunnResult::Done => {
// Sleep for a bit
tokio::time::sleep(Duration::from_millis(1)).await;
tokio::time::sleep(Duration::from_millis(250)).await;
}
other => {
tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
tokio::time::sleep(Duration::from_millis(250)).await;
}
};
}
@ -299,7 +302,7 @@ impl WgPeerData {
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
pub async fn routine_task(self) {
loop {
let mut send_buf = [0u8; MAX_PACKET];
let mut send_buf = vec![0u8; MAX_PACKET];
let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
self.handle_routine_tun_result(tun_result).await;
}
@ -462,7 +465,7 @@ impl WgTunnelListener {
}
});
let mut buf = [0u8; 4096];
let mut buf = vec![0u8; MAX_PACKET];
loop {
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
tracing::error!("Failed to receive from UDP socket");
@ -508,7 +511,14 @@ impl TunnelListener for WgTunnelListener {
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
let tunnel_url: TunnelUrl = self.addr.clone().into();
if let Some(bind_dev) = tunnel_url.bind_dev() {
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
} else {
setup_sokcet2(&socket2_socket, &addr)?;
}
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
self.tasks.spawn(Self::handle_udp_incoming(
self.get_udp_socket(),
@ -636,7 +646,7 @@ impl WgTunnelConnector {
let init = Self::create_handshake_init(&mut my_tun);
udp.send_to(&init, addr).await?;
let mut buf = [0u8; MAX_PACKET];
let mut buf = vec![0u8; MAX_PACKET];
let (n, _) = udp.recv_from(&mut buf).await.unwrap();
let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]);
udp.send_to(&keepalive, addr).await?;
@ -647,7 +657,7 @@ impl WgTunnelConnector {
let data = wg_peer.data.as_ref().unwrap().clone();
wg_peer.tasks.spawn(async move {
loop {
let mut buf = [0u8; MAX_PACKET];
let mut buf = vec![0u8; MAX_PACKET];
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
if recv_addr != addr {
continue;