improve hole punching and stun test (#124)

* implement new stun test algorithm, do test faster and provide more info
* support punching for symmetric
This commit is contained in:
Sijie.Sun 2024-06-02 07:20:57 +08:00 committed by GitHub
parent bdbb1f02d6
commit abf9d23d52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1326 additions and 343 deletions

View File

@ -59,6 +59,9 @@ message StunInfo {
NatType udp_nat_type = 1;
NatType tcp_nat_type = 2;
int64 last_update_time = 3;
repeated string public_ip = 4;
uint32 min_port = 5;
uint32 max_port = 6;
}
message Route {

View File

@ -0,0 +1,24 @@
#[doc(hidden)]
pub struct Defer<F: FnOnce()> {
// internal struct used by defer! macro
func: Option<F>,
}
impl<F: FnOnce()> Defer<F> {
pub fn new(func: F) -> Self {
Self { func: Some(func) }
}
}
impl<F: FnOnce()> Drop for Defer<F> {
fn drop(&mut self) {
self.func.take().map(|f| f());
}
}
#[macro_export]
macro_rules! defer {
( $($tt:tt)* ) => {
let _deferred = $crate::common::defer::Defer::new(|| { $($tt)* });
};
}

View File

@ -8,6 +8,7 @@ use tracing::Instrument;
pub mod config;
pub mod constants;
pub mod defer;
pub mod error;
pub mod global_ctx;
pub mod ifcfg;

View File

@ -75,7 +75,7 @@ impl NetNSGuard {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct NetNS {
name: Option<String>,
}

View File

@ -1,18 +1,20 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::collections::BTreeSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::rpc::{NatType, StunInfo};
use anyhow::Context;
use chrono::Local;
use crossbeam::atomic::AtomicCell;
use rand::seq::IteratorRandom;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::RwLock;
use tokio::sync::{broadcast, Mutex};
use tokio::task::JoinSet;
use tracing::Level;
use tracing::{Instrument, Level};
use bytecodec::{DecodeExt, EncodeExt};
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::rfc5780::attributes::ChangeRequest;
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
use crate::common::error::Error;
@ -22,13 +24,15 @@ use super::stun_codec_ext::*;
struct HostResolverIter {
hostnames: Vec<String>,
ips: Vec<SocketAddr>,
max_ip_per_domain: u32,
}
impl HostResolverIter {
fn new(hostnames: Vec<String>) -> Self {
fn new(hostnames: Vec<String>, max_ip_per_domain: u32) -> Self {
Self {
hostnames,
ips: vec![],
max_ip_per_domain,
}
}
@ -40,9 +44,17 @@ impl HostResolverIter {
}
let host = self.hostnames.remove(0);
let host = if host.contains(':') {
host
} else {
format!("{}:3478", host)
};
match lookup_host(&host).await {
Ok(ips) => {
self.ips = ips.collect();
self.ips = ips
.filter(|x| x.is_ipv4())
.choose_multiple(&mut rand::thread_rng(), self.max_ip_per_domain as usize);
}
Err(e) => {
tracing::warn!(?host, ?e, "lookup host for stun failed");
@ -55,19 +67,30 @@ impl HostResolverIter {
}
}
#[derive(Debug, Clone)]
struct StunPacket {
data: Vec<u8>,
addr: SocketAddr,
}
type StunPacketReceiver = tokio::sync::broadcast::Receiver<StunPacket>;
#[derive(Debug, Clone, Copy)]
struct BindRequestResponse {
source_addr: SocketAddr,
send_to_addr: SocketAddr,
local_addr: SocketAddr,
stun_server_addr: SocketAddr,
recv_from_addr: SocketAddr,
mapped_socket_addr: Option<SocketAddr>,
changed_socket_addr: Option<SocketAddr>,
ip_changed: bool,
port_changed: bool,
change_ip: bool,
change_port: bool,
real_ip_changed: bool,
real_port_changed: bool,
latency_us: u32,
}
impl BindRequestResponse {
@ -77,18 +100,26 @@ impl BindRequestResponse {
}
#[derive(Debug, Clone)]
struct Stun {
struct StunClient {
stun_server: SocketAddr,
req_repeat: u8,
resp_timeout: Duration,
req_repeat: u32,
socket: Arc<UdpSocket>,
stun_packet_receiver: Arc<Mutex<StunPacketReceiver>>,
}
impl Stun {
pub fn new(stun_server: SocketAddr) -> Self {
impl StunClient {
pub fn new(
stun_server: SocketAddr,
socket: Arc<UdpSocket>,
stun_packet_receiver: StunPacketReceiver,
) -> Self {
Self {
stun_server,
req_repeat: 2,
resp_timeout: Duration::from_millis(3000),
req_repeat: 2,
socket,
stun_packet_receiver: Arc::new(Mutex::new(stun_packet_receiver)),
}
}
@ -96,7 +127,6 @@ impl Stun {
async fn wait_stun_response<'a, const N: usize>(
&self,
buf: &'a mut [u8; N],
udp: &UdpSocket,
tids: &Vec<u128>,
expected_ip_changed: bool,
expected_port_changed: bool,
@ -106,16 +136,20 @@ impl Stun {
let deadline = now + self.resp_timeout;
while now < deadline {
let mut udp_buf = [0u8; 1500];
let (len, remote_addr) =
tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
.await??;
let mut locked_receiver = self.stun_packet_receiver.lock().await;
let stun_packet_raw = tokio::time::timeout(deadline - now, locked_receiver.recv())
.await?
.with_context(|| "recv stun packet from broadcast channel error")?;
now = tokio::time::Instant::now();
let (len, remote_addr) = (stun_packet_raw.data.len(), stun_packet_raw.addr);
if len < 20 {
continue;
}
let udp_buf = stun_packet_raw.data;
// TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here
unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
@ -136,18 +170,6 @@ impl Stun {
continue;
}
// some stun server use changed socket even we don't ask for.
if expected_ip_changed && stun_host.ip() == remote_addr.ip() {
continue;
}
if expected_port_changed
&& stun_host.ip() == remote_addr.ip()
&& stun_host.port() == remote_addr.port()
{
continue;
}
return Ok((msg, remote_addr));
}
@ -196,16 +218,14 @@ impl Stun {
#[tracing::instrument(ret, err, level = Level::DEBUG)]
pub async fn bind_request(
&self,
source_port: u16,
self,
change_ip: bool,
change_port: bool,
) -> Result<BindRequestResponse, Error> {
let stun_host = self.stun_server;
let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;
// repeat req in case of packet loss
let mut tids = vec![];
for _ in 0..self.req_repeat {
let tid = rand::random::<u32>();
// let tid = 1;
@ -222,16 +242,19 @@ impl Stun {
let msg = encoder
.encode_into_bytes(message.clone())
.with_context(|| "encode stun message")?;
tids.push(tid as u128);
tracing::trace!(?message, ?msg, tid, "send stun request");
udp.send_to(msg.as_slice().into(), &stun_host).await?;
tracing::debug!(?message, ?msg, tid, "send stun request");
self.socket
.send_to(msg.as_slice().into(), &stun_host)
.await?;
}
let now = Instant::now();
tracing::trace!("waiting stun response");
let mut buf = [0; 1620];
let (msg, recv_addr) = self
.wait_stun_response(&mut buf, &udp, &tids, change_ip, change_port, &stun_host)
.wait_stun_response(&mut buf, &tids, change_ip, change_port, &stun_host)
.await?;
let changed_socket_addr = Self::extract_changed_addr(&msg);
@ -239,16 +262,18 @@ impl Stun {
let real_port_changed = stun_host.port() != recv_addr.port();
let resp = BindRequestResponse {
source_addr: udp.local_addr()?,
send_to_addr: stun_host,
local_addr: self.socket.local_addr()?,
stun_server_addr: stun_host,
recv_from_addr: recv_addr,
mapped_socket_addr: Self::extrace_mapped_addr(&msg),
changed_socket_addr,
ip_changed: change_ip,
port_changed: change_port,
change_ip,
change_port,
real_ip_changed,
real_port_changed,
latency_us: now.elapsed().as_micros() as u32,
};
tracing::debug!(
@ -262,105 +287,256 @@ impl Stun {
}
}
pub struct UdpNatTypeDetector {
stun_servers: Vec<String>,
struct StunClientBuilder {
udp: Arc<UdpSocket>,
task_set: JoinSet<()>,
stun_packet_sender: broadcast::Sender<StunPacket>,
}
impl UdpNatTypeDetector {
pub fn new(stun_servers: Vec<String>) -> Self {
Self { stun_servers }
}
impl StunClientBuilder {
pub fn new(udp: Arc<UdpSocket>) -> Self {
let (stun_packet_sender, _) = broadcast::channel(1024);
let mut task_set = JoinSet::new();
pub async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
// Like classic STUN (rfc3489). Detect NAT behavior for UDP.
// Modified from rfc3489. Requires at least two STUN servers.
let mut ret_test1_1 = None;
let mut ret_test1_2 = None;
let mut ret_test2 = None;
let mut ret_test3 = None;
if source_port == 0 {
let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
source_port = udp.local_addr().unwrap().port();
}
let mut succ = false;
let mut ips = HostResolverIter::new(self.stun_servers.clone());
while let Some(server_ip) = ips.next().await {
let stun = Stun::new(server_ip.clone());
let ret = stun.bind_request(source_port, false, false).await;
if ret.is_err() {
// Try another STUN server
continue;
}
if ret_test1_1.is_none() {
ret_test1_1 = ret.ok();
continue;
}
ret_test1_2 = ret.ok();
let ret = stun.bind_request(source_port, true, true).await;
if let Ok(resp) = ret {
if !resp.real_ip_changed || !resp.real_port_changed {
tracing::debug!(
?server_ip,
?ret,
"stun bind request return with unchanged ip and port"
);
// Try another STUN server
continue;
let udp_clone = udp.clone();
let stun_packet_sender_clone = stun_packet_sender.clone();
task_set.spawn(
async move {
let mut buf = [0; 1620];
tracing::info!("start stun packet listener");
loop {
let Ok((len, addr)) = udp_clone.recv_from(&mut buf).await else {
tracing::error!("udp recv_from error");
break;
};
let data = buf[..len].to_vec();
tracing::debug!(?addr, ?data, "recv udp stun packet");
let _ = stun_packet_sender_clone.send(StunPacket { data, addr });
}
}
ret_test2 = ret.ok();
ret_test3 = stun.bind_request(source_port, false, true).await.ok();
tracing::debug!(?ret_test3, "stun bind request with changed port");
succ = true;
break;
}
.instrument(tracing::info_span!("stun_packet_listener")),
);
if !succ {
Self {
udp,
task_set,
stun_packet_sender,
}
}
pub fn new_stun_client(&self, stun_server: SocketAddr) -> StunClient {
StunClient::new(
stun_server,
self.udp.clone(),
self.stun_packet_sender.subscribe(),
)
}
pub async fn stop(&mut self) {
self.task_set.abort_all();
while let Some(_) = self.task_set.join_next().await {}
}
}
#[derive(Debug, Clone)]
pub struct UdpNatTypeDetectResult {
source_addr: SocketAddr,
stun_resps: Vec<BindRequestResponse>,
}
impl UdpNatTypeDetectResult {
fn new(source_addr: SocketAddr, stun_resps: Vec<BindRequestResponse>) -> Self {
Self {
source_addr,
stun_resps,
}
}
fn has_ip_changed_resp(&self) -> bool {
for resp in self.stun_resps.iter() {
if resp.real_ip_changed {
return true;
}
}
false
}
fn has_port_changed_resp(&self) -> bool {
for resp in self.stun_resps.iter() {
if resp.real_port_changed {
return true;
}
}
false
}
fn is_open_internet(&self) -> bool {
for resp in self.stun_resps.iter() {
if resp.mapped_socket_addr == Some(self.source_addr) {
return true;
}
}
return false;
}
fn is_pat(&self) -> bool {
for resp in self.stun_resps.iter() {
if resp.mapped_socket_addr.map(|x| x.port()) == Some(self.source_addr.port()) {
return true;
}
}
false
}
fn stun_server_count(&self) -> usize {
// find resp with distinct stun server
self.stun_resps
.iter()
.map(|x| x.stun_server_addr)
.collect::<BTreeSet<_>>()
.len()
}
fn is_cone(&self) -> bool {
// if unique mapped addr count is less than stun server count, it is cone
let mapped_addr_count = self
.stun_resps
.iter()
.filter_map(|x| x.mapped_socket_addr)
.collect::<BTreeSet<_>>()
.len();
mapped_addr_count < self.stun_server_count()
}
pub fn nat_type(&self) -> NatType {
if self.stun_server_count() < 2 {
return NatType::Unknown;
}
tracing::debug!(
?ret_test1_1,
?ret_test1_2,
?ret_test2,
?ret_test3,
"finish stun test, try to detect nat type"
);
let ret_test1_1 = ret_test1_1.unwrap();
let ret_test1_2 = ret_test1_2.unwrap();
if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
return NatType::Symmetric;
}
if ret_test1_1.mapped_socket_addr.is_some()
&& ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
{
if !ret_test2.is_none() {
return NatType::OpenInternet;
} else {
return NatType::SymUdpFirewall;
}
} else {
if let Some(ret_test2) = ret_test2 {
if source_port == ret_test2.get_mapped_addr_no_check().port()
&& source_port == ret_test1_1.get_mapped_addr_no_check().port()
{
if self.is_cone() {
if self.has_ip_changed_resp() {
if self.is_open_internet() {
return NatType::OpenInternet;
} else if self.is_pat() {
return NatType::NoPat;
} else {
return NatType::FullCone;
}
} else if self.has_port_changed_resp() {
return NatType::Restricted;
} else {
if !ret_test3.is_none() {
return NatType::Restricted;
} else {
return NatType::PortRestricted;
}
return NatType::PortRestricted;
}
} else if !self.stun_resps.is_empty() {
return NatType::Symmetric;
} else {
return NatType::Unknown;
}
}
pub fn public_ips(&self) -> Vec<IpAddr> {
self.stun_resps
.iter()
.filter_map(|x| x.mapped_socket_addr.map(|x| x.ip()))
.collect::<BTreeSet<_>>()
.into_iter()
.collect()
}
pub fn collect_available_stun_server(&self) -> Vec<SocketAddr> {
let mut ret = vec![];
for resp in self.stun_resps.iter() {
if !ret.contains(&resp.stun_server_addr) {
ret.push(resp.stun_server_addr);
}
}
ret
}
pub fn local_addr(&self) -> SocketAddr {
self.source_addr
}
pub fn extend_result(&mut self, other: UdpNatTypeDetectResult) {
self.stun_resps.extend(other.stun_resps);
}
pub fn min_port(&self) -> u16 {
self.stun_resps
.iter()
.filter_map(|x| x.mapped_socket_addr.map(|x| x.port()))
.min()
.unwrap_or(0)
}
pub fn max_port(&self) -> u16 {
self.stun_resps
.iter()
.filter_map(|x| x.mapped_socket_addr.map(|x| x.port()))
.max()
.unwrap_or(u16::MAX)
}
}
pub struct UdpNatTypeDetector {
stun_server_hosts: Vec<String>,
max_ip_per_domain: u32,
}
impl UdpNatTypeDetector {
pub fn new(stun_server_hosts: Vec<String>, max_ip_per_domain: u32) -> Self {
Self {
stun_server_hosts,
max_ip_per_domain,
}
}
pub async fn detect_nat_type(&self, source_port: u16) -> Result<UdpNatTypeDetectResult, Error> {
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
self.detect_nat_type_with_socket(udp).await
}
#[tracing::instrument(skip(self))]
pub async fn detect_nat_type_with_socket(
&self,
udp: Arc<UdpSocket>,
) -> Result<UdpNatTypeDetectResult, Error> {
let mut stun_servers = vec![];
let mut host_resolver =
HostResolverIter::new(self.stun_server_hosts.clone(), self.max_ip_per_domain);
while let Some(addr) = host_resolver.next().await {
stun_servers.push(addr);
}
let client_builder = StunClientBuilder::new(udp.clone());
let mut stun_task_set = JoinSet::new();
for stun_server in stun_servers.iter() {
stun_task_set.spawn(
client_builder
.new_stun_client(*stun_server)
.bind_request(false, false),
);
stun_task_set.spawn(
client_builder
.new_stun_client(*stun_server)
.bind_request(false, true),
);
stun_task_set.spawn(
client_builder
.new_stun_client(*stun_server)
.bind_request(true, true),
);
}
let mut bind_resps = vec![];
while let Some(resp) = stun_task_set.join_next().await {
if let Ok(Ok(resp)) = resp {
bind_resps.push(resp);
}
}
Ok(UdpNatTypeDetectResult::new(udp.local_addr()?, bind_resps))
}
}
@ -373,7 +549,8 @@ pub trait StunInfoCollectorTrait: Send + Sync {
pub struct StunInfoCollector {
stun_servers: Arc<RwLock<Vec<String>>>,
udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
udp_nat_test_result: Arc<RwLock<Option<UdpNatTypeDetectResult>>>,
nat_test_result_time: Arc<AtomicCell<chrono::DateTime<Local>>>,
redetect_notify: Arc<tokio::sync::Notify>,
tasks: JoinSet<()>,
}
@ -381,27 +558,47 @@ pub struct StunInfoCollector {
#[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
let (typ, time) = self.udp_nat_type.load();
let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else {
return Default::default();
};
StunInfo {
udp_nat_type: typ as i32,
udp_nat_type: result.nat_type() as i32,
tcp_nat_type: 0,
last_update_time: time.elapsed().as_secs() as i64,
last_update_time: self.nat_test_result_time.load().timestamp(),
public_ip: result.public_ips().iter().map(|x| x.to_string()).collect(),
min_port: result.min_port() as u32,
max_port: result.max_port() as u32,
}
}
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
let stun_servers = self.stun_servers.read().await.clone();
let mut ips = HostResolverIter::new(stun_servers.clone());
while let Some(server) = ips.next().await {
let stun = Stun::new(server.clone());
let Ok(ret) = stun.bind_request(local_port, false, false).await else {
let stun_servers = self
.udp_nat_test_result
.read()
.unwrap()
.clone()
.map(|x| x.collect_available_stun_server())
.ok_or(Error::NotFound)?;
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", local_port)).await?);
let mut client_builder = StunClientBuilder::new(udp.clone());
for server in stun_servers.iter() {
let Ok(ret) = client_builder
.new_stun_client(*server)
.bind_request(false, false)
.await
else {
tracing::warn!(?server, "stun bind request failed");
continue;
};
if let Some(mapped_addr) = ret.mapped_socket_addr {
// make sure udp socket is available after return ok.
client_builder.stop().await;
return Ok(mapped_addr);
}
}
Err(Error::NotFound)
}
}
@ -410,10 +607,8 @@ impl StunInfoCollector {
pub fn new(stun_servers: Vec<String>) -> Self {
let mut ret = Self {
stun_servers: Arc::new(RwLock::new(stun_servers)),
udp_nat_type: Arc::new(AtomicCell::new((
NatType::Unknown,
std::time::Instant::now(),
))),
udp_nat_test_result: Arc::new(RwLock::new(None)),
nat_test_result_time: Arc::new(AtomicCell::new(Local::now())),
redetect_notify: Arc::new(tokio::sync::Notify::new()),
tasks: JoinSet::new(),
};
@ -431,46 +626,78 @@ impl StunInfoCollector {
// NOTICE: we may need to choose stun stun server based on geo location
// stun server cross nation may return a external ip address with high latency and loss rate
vec![
"stun.miwifi.com:3478".to_string(),
"stun.chat.bilibili.com:3478".to_string(), // bilibili's stun server doesn't repond to change_ip and change_port
"stun.cloudflare.com:3478".to_string(),
"stun.syncthing.net:3478".to_string(),
"stun.isp.net.au:3478".to_string(),
"stun.nextcloud.com:3478".to_string(),
"stun.freeswitch.org:3478".to_string(),
"stun.voip.blackberry.com:3478".to_string(),
"stunserver.stunprotocol.org:3478".to_string(),
"stun.sipnet.com:3478".to_string(),
"stun.radiojar.com:3478".to_string(),
"stun.sonetel.com:3478".to_string(),
"stun.voipgate.com:3478".to_string(),
"stun.miwifi.com",
"stun.cdnbye.com",
"stun.hitv.com",
"stun.chat.bilibili.com",
"stun.douyucdn.cn:18000",
"fwa.lifesizecloud.com",
"global.turn.twilio.com",
"turn.cloudflare.com",
"stun.isp.net.au",
"stun.nextcloud.com",
"stun.freeswitch.org",
"stun.voip.blackberry.com",
"stunserver.stunprotocol.org",
"stun.sipnet.com",
"stun.radiojar.com",
"stun.sonetel.com",
]
.iter()
.map(|x| x.to_string())
.collect()
}
fn start_stun_routine(&mut self) {
let stun_servers = self.stun_servers.clone();
let udp_nat_type = self.udp_nat_type.clone();
let udp_nat_test_result = self.udp_nat_test_result.clone();
let udp_test_time = self.nat_test_result_time.clone();
let redetect_notify = self.redetect_notify.clone();
self.tasks.spawn(async move {
loop {
let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
let old_nat_type = udp_nat_type.load().0;
let mut ret = NatType::Unknown;
for _ in 1..5 {
// if nat type degrade, sleep and retry. so result can be relatively stable.
ret = detector.get_udp_nat_type(0).await;
if ret == NatType::Unknown || ret <= old_nat_type {
break;
let servers = stun_servers.read().unwrap().clone();
// use first three and random choose one from the rest
let servers = servers
.iter()
.take(2)
.chain(servers.iter().skip(2).choose(&mut rand::thread_rng()))
.map(|x| x.to_string())
.collect();
let detector = UdpNatTypeDetector::new(servers, 1);
let ret = detector.detect_nat_type(0).await;
tracing::debug!(?ret, "finish udp nat type detect");
let mut nat_type = NatType::Unknown;
let sleep_sec = match &ret {
Ok(resp) => {
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
udp_test_time.store(Local::now());
nat_type = resp.nat_type();
if nat_type == NatType::Unknown {
15
} else {
600
}
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
udp_nat_type.store((ret, std::time::Instant::now()));
let sleep_sec = match ret {
NatType::Unknown => 15,
_ => 60,
_ => 15,
};
tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
// if nat type is symmtric, detect with another port to gather more info
if nat_type == NatType::Symmetric {
let old_resp = ret.unwrap();
let old_local_port = old_resp.local_addr().port();
let new_port = if old_local_port >= 65535 {
old_local_port - 1
} else {
old_local_port + 1
};
let ret = detector.detect_nat_type(new_port).await;
tracing::debug!(?ret, "finish udp nat type detect with another port");
if let Ok(resp) = ret {
udp_nat_test_result.write().unwrap().as_mut().map(|x| {
x.extend_result(resp);
});
}
}
tokio::select! {
_ = redetect_notify.notified() => {}
@ -483,53 +710,26 @@ impl StunInfoCollector {
pub fn update_stun_info(&self) {
self.redetect_notify.notify_one();
}
pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
*self.stun_servers.write().await = stun_servers;
self.update_stun_info();
}
}
#[cfg(test)]
mod tests {
use super::*;
pub fn enable_log() {
let filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
.from_env()
.unwrap()
.add_directive("tarpc=error".parse().unwrap());
tracing_subscriber::fmt::fmt()
.pretty()
.with_env_filter(filter)
.init();
}
#[tokio::test]
async fn test_stun_bind_request() {
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
// let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]);
let mut ips_ = HostResolverIter::new(vec!["stun.canets.org:3478".to_string()]);
let mut ips = vec![];
while let Some(ip) = ips_.next().await {
ips.push(ip);
async fn test_udp_nat_type_detector() {
let collector = StunInfoCollector::new_with_default_servers();
collector.update_stun_info();
loop {
let ret = collector.get_stun_info();
if ret.udp_nat_type != NatType::Unknown as i32 {
println!("{:#?}", ret);
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
println!("ip: {:?}", ips);
for ip in ips.iter() {
let stun = Stun::new(ip.clone());
let _rs = stun.bind_request(12345, true, true).await;
}
}
#[tokio::test]
async fn test_udp_nat_type_detect() {
let detector = UdpNatTypeDetector::new(vec![
"stun.counterpath.com:3478".to_string(),
"180.235.108.91:3478".to_string(),
]);
let ret = detector.get_udp_nat_type(0).await;
assert_ne!(ret, NatType::Unknown);
let port_mapping = collector.get_udp_port_mapping(3000).await;
println!("{:#?}", port_mapping);
}
}

View File

@ -1,11 +1,12 @@
use std::net::SocketAddr;
use bytecodec::fixnum::{U32beDecoder, U32beEncoder};
use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder};
use stun_codec::rfc5389::attributes::{
MappedAddress, Software, XorMappedAddress, XorMappedAddress2,
};
use stun_codec::rfc5780::attributes::{ChangeRequest, OtherAddress, ResponseOrigin};
use stun_codec::rfc5780::attributes::{OtherAddress, ResponseOrigin};
use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId};
use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode};
@ -197,6 +198,75 @@ impl_encode!(SourceAddressEncoder, SourceAddress, |item: Self::Item| {
item.0
});
/// `CHANGE-REQUEST` attribute.
///
/// See [RFC 5780 -- 7.2. CHANGE-REQUEST] about this attribute.
///
/// [RFC 5780 -- 7.2. CHANGE-REQUEST]: https://tools.ietf.org/html/rfc5780#section-7.2
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChangeRequest(bool, bool);
impl ChangeRequest {
/// The codepoint of the type of the attribute.
pub const CODEPOINT: u16 = 0x0003;
/// Makes a new `ChangeRequest` instance.
pub fn new(ip: bool, port: bool) -> Self {
ChangeRequest(ip, port)
}
/// Returns whether the client requested the server to send the Binding Response with a
/// different IP address than the one the Binding Request was received on
pub fn ip(&self) -> bool {
self.0
}
/// Returns whether the client requested the server to send the Binding Response with a
/// different port than the one the Binding Request was received on
pub fn port(&self) -> bool {
self.1
}
}
impl stun_codec::Attribute for ChangeRequest {
type Decoder = ChangeRequestDecoder;
type Encoder = ChangeRequestEncoder;
fn get_type(&self) -> AttributeType {
AttributeType::new(Self::CODEPOINT)
}
}
/// [`ChangeRequest`] decoder.
#[derive(Debug, Default)]
pub struct ChangeRequestDecoder(U32beDecoder);
impl ChangeRequestDecoder {
/// Makes a new `ChangeRequestDecoder` instance.
pub fn new() -> Self {
Self::default()
}
}
impl_decode!(ChangeRequestDecoder, ChangeRequest, |item| {
Ok(ChangeRequest((item & 0x4) != 0, (item & 0x2) != 0))
});
/// [`ChangeRequest`] encoder.
#[derive(Debug, Default)]
pub struct ChangeRequestEncoder(U32beEncoder);
impl ChangeRequestEncoder {
/// Makes a new `ChangeRequestEncoder` instance.
pub fn new() -> Self {
Self::default()
}
}
impl_encode!(ChangeRequestEncoder, ChangeRequest, |item: Self::Item| {
let ip = item.0 as u8;
let port = item.1 as u8;
((ip << 1 | port) << 1) as u32
});
pub fn tid_to_u128(tid: &TransactionId) -> u128 {
let mut tid_buf = [0u8; 16];
// copy bytes from msg_tid to tid_buf

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
#![allow(dead_code)]
use std::{net::SocketAddr, vec};
use std::{net::SocketAddr, time::Duration, vec};
use clap::{command, Args, Parser, Subcommand};
use common::stun::StunInfoCollectorTrait;
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient;
use tokio::time::timeout;
use utils::{list_peer_route_pair, PeerRoutePair};
mod arch;
@ -13,7 +15,7 @@ mod tunnel;
mod utils;
use crate::{
common::stun::{StunInfoCollector, UdpNatTypeDetector},
common::stun::StunInfoCollector,
rpc::{
connector_manage_rpc_client::ConnectorManageRpcClient,
peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient,
@ -309,8 +311,19 @@ async fn main() -> Result<(), Error> {
handler.handle_route_list().await?;
}
SubCommand::Stun => {
let stun = UdpNatTypeDetector::new(StunInfoCollector::get_default_servers());
println!("udp type: {:?}", stun.get_udp_nat_type(0).await);
timeout(Duration::from_secs(5), async move {
let collector = StunInfoCollector::new_with_default_servers();
loop {
let ret = collector.get_stun_info();
if ret.udp_nat_type != NatType::Unknown as i32 {
println!("stun info: {:#?}", ret);
break;
}
tokio::time::sleep(Duration::from_millis(200)).await;
}
})
.await
.unwrap();
}
SubCommand::PeerCenter => {
let mut peer_center_client = handler.get_peer_center_client().await?;

View File

@ -76,7 +76,7 @@ impl PeerConnPinger {
let now = std::time::Instant::now();
// wait until we get a pong packet in ctrl_resp_receiver
let resp = timeout(Duration::from_secs(1), async {
let resp = timeout(Duration::from_secs(2), async {
loop {
match receiver.recv().await {
Ok(p) => {

View File

@ -77,16 +77,16 @@ fn new_sack_packet(conn_id: u32, magic: u64) -> ZCPacket {
)
}
pub fn new_hole_punch_packet() -> ZCPacket {
pub fn new_hole_punch_packet(tid: u32, buf_len: u16) -> ZCPacket {
// generate a 128 bytes vec with random data
let mut rng = rand::rngs::StdRng::from_entropy();
let mut buf = vec![0u8; 128];
let mut buf = vec![0u8; buf_len as usize];
rng.fill(&mut buf[..]);
new_udp_packet(
|header| {
header.msg_type = UdpPacketType::HolePunch as u8;
header.conn_id.set(0);
header.len.set(0);
header.conn_id.set(tid);
header.len.set(buf_len);
},
Some(&mut buf),
)
@ -304,7 +304,7 @@ impl UdpTunnelListenerData {
let header = zc_packet.udp_tunnel_header().unwrap();
if header.msg_type == UdpPacketType::Syn as u8 {
tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet));
} else {
} else if header.msg_type != UdpPacketType::HolePunch as u8 {
if let Err(e) = self
.try_forward_packet(addr, header.conn_id.get(), zc_packet)
.await
@ -526,11 +526,10 @@ impl UdpTunnelConnector {
async fn build_tunnel(
&self,
socket: UdpSocket,
socket: Arc<UdpSocket>,
dst_addr: SocketAddr,
conn_id: u32,
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let socket = Arc::new(socket);
let ring_for_send_udp = Arc::new(RingTunnel::new(128));
let ring_for_recv_udp = Arc::new(RingTunnel::new(128));
tracing::debug!(
@ -610,13 +609,13 @@ impl UdpTunnelConnector {
pub async fn try_connect_with_socket(
&self,
socket: UdpSocket,
socket: Arc<UdpSocket>,
addr: SocketAddr,
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
log::warn!("udp connect: {:?}", self.addr);
#[cfg(target_os = "windows")]
crate::arch::windows::disable_connection_reset(&socket)?;
crate::arch::windows::disable_connection_reset(socket.as_ref())?;
// send syn
let conn_id = rand::random();
@ -649,7 +648,7 @@ impl UdpTunnelConnector {
UdpSocket::bind("[::]:0").await?
};
return self.try_connect_with_socket(socket, addr).await;
return self.try_connect_with_socket(Arc::new(socket), addr).await;
}
async fn connect_with_custom_bind(
@ -666,7 +665,7 @@ impl UdpTunnelConnector {
)?;
setup_sokcet2(&socket2_socket, &bind_addr)?;
let socket = UdpSocket::from_std(socket2_socket.into())?;
futures.push(self.try_connect_with_socket(socket, addr));
futures.push(self.try_connect_with_socket(Arc::new(socket), addr));
}
wait_for_connect_futures(futures).await
}