mirror of https://github.com/EasyTier/EasyTier.git, synced 2024-11-16 11:42:27 +08:00
some minor bug fixes (#41)
* fix joinset leak
* fix udp packet format
* fix trace log panic
* avoid waiting after listener accept
This commit is contained in:
parent 0f6f553010
commit ce889e990e
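The core of the joinset-leak fix is the new join_joinset_background helper added in the first hunk below. A minimal usage sketch of the pattern the commit's call sites (TcpProxy, UdpTunnelListener, UdpHolePunchRpcServer) follow; the Worker type is hypothetical, and it assumes a tokio runtime is running when new() is called:

use std::sync::{Arc, Mutex};

use tokio::task::JoinSet;

use crate::common::join_joinset_background;

// Hypothetical in-crate component following the pattern introduced by this commit:
// keep the JoinSet behind a std (not tokio) Mutex, register the background joiner
// once, then spawn tasks through lock().unwrap() without awaiting while locked.
struct Worker {
    tasks: Arc<Mutex<JoinSet<()>>>,
}

impl Worker {
    fn new() -> Self {
        let tasks = Arc::new(Mutex::new(JoinSet::new()));
        // Detached task that reaps finished entries roughly once per second and
        // exits on its own after the last strong Arc to `tasks` is dropped.
        join_joinset_background(tasks.clone(), "Worker".to_owned());
        Self { tasks }
    }

    fn spawn_work(&self) {
        self.tasks.lock().unwrap().spawn(async {
            // ... actual work ...
        });
    }
}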
@@ -1,3 +1,11 @@
+use std::{
+    fmt::Debug,
+    future,
+    sync::{Arc, Mutex},
+};
+use tokio::task::JoinSet;
+use tracing::Instrument;
+
 pub mod config;
 pub mod constants;
 pub mod error;
@@ -30,3 +38,83 @@ pub type PeerId = u32;
 pub fn new_peer_id() -> PeerId {
     rand::random()
 }
+
+pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
+    js: Arc<Mutex<JoinSet<T>>>,
+    origin: String,
+) {
+    let js = Arc::downgrade(&js);
+    tokio::spawn(
+        async move {
+            loop {
+                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+                if js.weak_count() == 0 {
+                    tracing::info!("joinset task exit");
+                    break;
+                }
+
+                future::poll_fn(|cx| {
+                    tracing::info!("try join joinset tasks");
+                    let Some(js) = js.upgrade() else {
+                        return std::task::Poll::Ready(());
+                    };
+
+                    let mut js = js.lock().unwrap();
+                    while !js.is_empty() {
+                        let ret = js.poll_join_next(cx);
+                        if ret.is_pending() {
+                            return std::task::Poll::Pending;
+                        }
+                    }
+
+                    std::task::Poll::Ready(())
+                })
+                .await;
+            }
+        }
+        .instrument(tracing::info_span!(
+            "join_joinset_background",
+            origin = origin
+        )),
+    );
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_join_joinset_backgroud() {
+        let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
+        join_joinset_background(js.clone(), "TEST".to_owned());
+        js.try_lock().unwrap().spawn(async {
+            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+        });
+        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+        assert!(js.try_lock().unwrap().is_empty());
+
+        for _ in 0..5 {
+            js.try_lock().unwrap().spawn(async {
+                tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+            });
+            tokio::task::yield_now().await;
+        }
+
+        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+
+        for _ in 0..5 {
+            js.try_lock().unwrap().spawn(async {
+                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+            });
+            tokio::task::yield_now().await;
+        }
+
+        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+        assert!(js.try_lock().unwrap().is_empty());
+
+        let weak_js = Arc::downgrade(&js);
+        drop(js);
+        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+        assert_eq!(weak_js.weak_count(), 0);
+    }
+}
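The helper's exit condition relies on a documented detail of std::sync::Weak: weak_count() returns 0 once no strong references remain, even while the Weak handle itself is still alive (the new test above asserts exactly this after drop(js)). A standalone sketch of that behaviour, independent of the commit:

use std::sync::Arc;

fn main() {
    let strong = Arc::new(42u32);
    let weak = Arc::downgrade(&strong);
    // One Weak exists and the strong owner is still alive.
    assert_eq!(weak.weak_count(), 1);

    drop(strong);
    // With no strong references left, weak_count() reports 0, which is the
    // signal join_joinset_background uses to stop its polling loop.
    assert_eq!(weak.weak_count(), 0);
    assert!(weak.upgrade().is_none());
}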
@@ -8,8 +8,8 @@ use tracing::Instrument;
 
 use crate::{
     common::{
-        constants, error::Error, global_ctx::ArcGlobalCtx, rkyv_util::encode_to_bytes,
-        stun::StunInfoCollectorTrait, PeerId,
+        constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
+        rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
     },
     peers::peer_manager::PeerManager,
     rpc::NatType,
@@ -75,9 +75,15 @@ impl UdpHolePunchListener {
             while let Ok(conn) = listener.accept().await {
                 last_connected_time_clone.store(std::time::Instant::now());
                 tracing::warn!(?conn, "udp hole punching listener got peer connection");
-                if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
-                    tracing::error!(?e, "failed to add tunnel as server in hole punch listener");
-                }
+                let peer_mgr = peer_mgr.clone();
+                tokio::spawn(async move {
+                    if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
+                        tracing::error!(
+                            ?e,
+                            "failed to add tunnel as server in hole punch listener"
+                        );
+                    }
+                });
             }
 
             running_clone.store(false);
@@ -115,7 +121,7 @@ struct UdpHolePunchConnectorData {
 struct UdpHolePunchRpcServer {
     data: Arc<UdpHolePunchConnectorData>,
 
-    tasks: Arc<Mutex<JoinSet<()>>>,
+    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
 }
 
 #[tarpc::server]
@@ -140,7 +146,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
             || my_udp_nat_type == NatType::Restricted as i32
         {
            // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
-            self.tasks.lock().await.spawn(async move {
+            self.tasks.lock().unwrap().spawn(async move {
                for _ in 0..10 {
                    tracing::info!(?local_mapped_addr, "sending hole punching packet");
                    // generate a 128 bytes vec with random data
@@ -164,10 +170,9 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
 
 impl UdpHolePunchRpcServer {
     pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
-        Self {
-            data,
-            tasks: Arc::new(Mutex::new(JoinSet::new())),
-        }
+        let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
+        join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
+        Self { data, tasks }
     }
 
     async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
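The "avoid waiting after listener accept" part of the commit moves per-connection work off the accept loop, both in the hole-punch listener above and in the listener manager hunk further down. A generic sketch of the pattern (plain tokio, not EasyTier code):

use tokio::net::TcpListener;

// Each accepted connection is handled in its own task, so one slow or failing
// handler can no longer stall the accept loop the way the awaited
// add_tunnel_as_server call did before this change.
async fn accept_loop(listener: TcpListener) {
    while let Ok((stream, peer_addr)) = listener.accept().await {
        tokio::spawn(async move {
            if let Err(e) = handle_conn(stream).await {
                tracing::error!(?e, ?peer_addr, "failed to handle connection");
            }
        });
    }
}

async fn handle_conn(_stream: tokio::net::TcpStream) -> std::io::Result<()> {
    // ... protocol handling would go here ...
    Ok(())
}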
@@ -16,6 +16,7 @@ use tracing::Instrument;
 
 use crate::common::error::Result;
 use crate::common::global_ctx::GlobalCtx;
+use crate::common::join_joinset_background;
 use crate::common::netns::NetNS;
 use crate::peers::packet::{self, ArchivedPacket};
 use crate::peers::peer_manager::PeerManager;
@@ -71,7 +72,7 @@ pub struct TcpProxy {
     peer_manager: Arc<PeerManager>,
     local_port: AtomicU16,
 
-    tasks: Arc<Mutex<JoinSet<()>>>,
+    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
 
     syn_map: SynSockMap,
     conn_map: ConnSockMap,
@@ -215,7 +216,7 @@ impl TcpProxy {
             peer_manager,
 
             local_port: AtomicU16::new(0),
-            tasks: Arc::new(Mutex::new(JoinSet::new())),
+            tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
 
             syn_map: Arc::new(DashMap::new()),
             conn_map: Arc::new(DashMap::new()),
@@ -247,6 +248,7 @@ impl TcpProxy {
         self.peer_manager
             .add_nic_packet_process_pipeline(Box::new(self.clone()))
             .await;
+        join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
 
         Ok(())
     }
@@ -268,7 +270,7 @@ impl TcpProxy {
                 tokio::time::sleep(Duration::from_secs(10)).await;
             }
         };
-        tasks.lock().await.spawn(syn_map_cleaner_task);
+        tasks.lock().unwrap().spawn(syn_map_cleaner_task);
 
         Ok(())
     }
@@ -312,7 +314,7 @@ impl TcpProxy {
             let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
             assert!(old_nat_val.is_none());
 
-            tasks.lock().await.spawn(Self::connect_to_nat_dst(
+            tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
                 net_ns.clone(),
                 tcp_stream,
                 conn_map.clone(),
@@ -325,7 +327,7 @@ impl TcpProxy {
         };
         self.tasks
             .lock()
-            .await
+            .unwrap()
             .spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
 
         Ok(())
@@ -100,15 +100,19 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
                     tunnel_info.remote_addr.clone(),
                 ));
                 tracing::info!(ret = ?ret, "conn accepted");
-                let server_ret = peer_manager.handle_tunnel(ret).await;
-                if let Err(e) = &server_ret {
-                    global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
-                        tunnel_info.local_addr,
-                        tunnel_info.remote_addr,
-                        e.to_string(),
-                    ));
-                    tracing::error!(error = ?e, "handle conn error");
-                }
+                let peer_manager = peer_manager.clone();
+                let global_ctx = global_ctx.clone();
+                tokio::spawn(async move {
+                    let server_ret = peer_manager.handle_tunnel(ret).await;
+                    if let Err(e) = &server_ret {
+                        global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
+                            tunnel_info.local_addr,
+                            tunnel_info.remote_addr,
+                            e.to_string(),
+                        ));
+                        tracing::error!(error = ?e, "handle conn error");
+                    }
+                });
             }
         }
 
@@ -99,7 +99,7 @@ pub enum PacketType {
     TaRpc = 6,
 }
 
-#[derive(Archive, Deserialize, Serialize, Debug)]
+#[derive(Archive, Deserialize, Serialize)]
 #[archive(compare(PartialEq), check_bytes)]
 // Derives can be passed through to the generated type:
 pub struct Packet {
@@ -109,6 +109,19 @@ pub struct Packet {
     pub payload: String,
 }
 
+impl std::fmt::Debug for Packet {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
+            self.from_peer,
+            self.to_peer,
+            self.packet_type,
+            &self.payload.as_bytes()
+        )
+    }
+}
+
 impl std::fmt::Debug for ArchivedPacket {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(
@@ -13,7 +13,10 @@ use tokio_util::{
 use tracing::Instrument;
 
 use crate::{
-    common::rkyv_util::{self, encode_to_bytes, vec_to_string},
+    common::{
+        join_joinset_background,
+        rkyv_util::{self, encode_to_bytes, vec_to_string},
+    },
     rpc::TunnelInfo,
     tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
 };
@@ -27,7 +30,7 @@ use super::{
 
 pub const UDP_DATA_MTU: usize = 2500;
 
-#[derive(Archive, Deserialize, Serialize, Debug)]
+#[derive(Archive, Deserialize, Serialize)]
 #[archive(compare(PartialEq), check_bytes)]
 // Derives can be passed through to the generated type:
 pub enum UdpPacketPayload {
@@ -37,14 +40,29 @@ pub enum UdpPacketPayload {
     Data(String),
 }
 
+impl std::fmt::Debug for UdpPacketPayload {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
+        match self {
+            UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
+            UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
+            UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
+            UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
+        }
+    }
+}
+
 #[derive(Archive, Deserialize, Serialize, Debug)]
 #[archive(compare(PartialEq), check_bytes)]
 #[archive_attr(derive(Debug))]
 pub struct UdpPacket {
     pub conn_id: u32,
+    pub magic: u32,
     pub payload: UdpPacketPayload,
 }
 
+const UDP_PACKET_MAGIC: u32 = 0x19941126;
+
 impl std::fmt::Debug for ArchivedUdpPacketPayload {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
@@ -63,6 +81,7 @@ impl UdpPacket {
     pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
         Self {
             conn_id,
+            magic: UDP_PACKET_MAGIC,
             payload: UdpPacketPayload::Data(vec_to_string(data)),
         }
     }
@@ -70,6 +89,7 @@ impl UdpPacket {
     pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
         Self {
             conn_id: 0,
+            magic: UDP_PACKET_MAGIC,
             payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
         }
     }
@@ -77,6 +97,7 @@ impl UdpPacket {
     pub fn new_syn_packet(conn_id: u32) -> Self {
         Self {
             conn_id,
+            magic: UDP_PACKET_MAGIC,
             payload: UdpPacketPayload::Syn,
         }
     }
@@ -84,6 +105,7 @@ impl UdpPacket {
     pub fn new_sack_packet(conn_id: u32) -> Self {
         Self {
             conn_id,
+            magic: UDP_PACKET_MAGIC,
             payload: UdpPacketPayload::Sack,
         }
     }
@@ -100,6 +122,11 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
         return None;
     }
 
+    if udp_packet.magic != UDP_PACKET_MAGIC {
+        tracing::warn!(?udp_packet, "udp magic not match");
+        return None;
+    }
+
     let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
         tracing::warn!(?udp_packet, "udp payload not data");
         return None;
@@ -189,7 +216,7 @@ pub struct UdpTunnelListener {
     socket: Option<Arc<UdpSocket>>,
 
     sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
-    forward_tasks: Arc<Mutex<JoinSet<()>>>,
+    forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
 
     conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
     conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
@@ -202,7 +229,7 @@ impl UdpTunnelListener {
             addr,
             socket: None,
             sock_map: Arc::new(DashMap::new()),
-            forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
+            forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
             conn_recv,
             conn_send: Some(conn_send),
         }
@@ -234,7 +261,7 @@ impl UdpTunnelListener {
     async fn handle_connect(
         socket: Arc<UdpSocket>,
         addr: SocketAddr,
-        forward_tasks: Arc<Mutex<JoinSet<()>>>,
+        forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
         sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
         local_url: url::Url,
         conn_id: u32,
@@ -251,7 +278,7 @@ impl UdpTunnelListener {
         let addr_copy = addr.clone();
         sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
         let ctunnel_stream = ctunnel.pin_stream();
-        forward_tasks.lock().await.spawn(async move {
+        forward_tasks.lock().unwrap().spawn(async move {
             let ret = ctunnel_stream
                 .map(|v| {
                     tracing::trace!(?v, "udp stream recv something in forward task");
@@ -304,7 +331,7 @@ impl TunnelListener for UdpTunnelListener {
         let sock_map = self.sock_map.clone();
         let conn_send = self.conn_send.take().unwrap();
         let local_url = self.local_url().clone();
-        self.forward_tasks.lock().await.spawn(
+        self.forward_tasks.lock().unwrap().spawn(
             async move {
                 loop {
                     let mut buf = BytesMut::new();
@@ -323,6 +350,11 @@ impl TunnelListener for UdpTunnelListener {
                         continue;
                     };
 
+                    if udp_packet.magic != UDP_PACKET_MAGIC {
+                        tracing::info!(?udp_packet, "udp magic not match");
+                        continue;
+                    }
+
                     if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
                         let Ok(conn) = Self::handle_connect(
                             socket.clone(),
@@ -350,22 +382,7 @@ impl TunnelListener for UdpTunnelListener {
             .instrument(tracing::info_span!("udp forward task", ?self.socket)),
         );
 
-        // let forward_tasks_clone = self.forward_tasks.clone();
-        // tokio::spawn(async move {
-        //     loop {
-        //         let mut locked_forward_tasks = forward_tasks_clone.lock().await;
-        //         tokio::select! {
-        //             ret = locked_forward_tasks.join_next() => {
-        //                 tracing::warn!(?ret, "udp forward task exit");
-        //             }
-        //             else => {
-        //                 drop(locked_forward_tasks);
-        //                 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
-        //                 continue;
-        //             }
-        //         }
-        //     }
-        // });
+        join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
 
         Ok(())
     }
@@ -453,6 +470,14 @@ impl UdpTunnelConnector {
             )));
         };
 
+        if udp_packet.magic != UDP_PACKET_MAGIC {
+            tracing::info!(?udp_packet, "udp magic not match");
+            return Err(super::TunnelError::ConnectError(format!(
+                "udp connect error, magic not match. magic: {:?}",
+                udp_packet.magic
+            )));
+        }
+
         if conn_id != udp_packet.conn_id {
             return Err(super::TunnelError::ConnectError(format!(
                 "udp connect error, conn id not match. conn_id: {:?}, {:?}",