use customized rpc implementation, remove Tarpc & Tonic (#348)

This patch removes Tarpc and the Tonic gRPC stack and implements a custom RPC framework, which can be used by both the peer RPC and the CLI interface.

The web config server can also use this RPC framework.

Moreover, this rewrites the public-server logic, using OSPF routing to implement public-server-based networking. This makes a public-server mesh possible.
This commit is contained in:
Sijie.Sun 2024-09-18 21:55:28 +08:00 committed by GitHub
parent 0467b0a3dc
commit 1b03223537
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 3844 additions and 2856 deletions

295
Cargo.lock generated
View File

@ -369,15 +369,6 @@ dependencies = [
"system-deps",
]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]]
name = "atomic-shim"
version = "0.2.0"
@ -427,53 +418,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "axum"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
"async-trait",
"axum-core",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"itoa 1.0.11",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 1.0.1",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 0.1.2",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.73"
@ -960,12 +904,6 @@ dependencies = [
"error-code",
]
[[package]]
name = "cobs"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]]
name = "cocoa"
version = "0.25.0"
@ -1176,12 +1114,6 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "critical-section"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216"
[[package]]
name = "crossbeam"
version = "0.8.4"
@ -1638,14 +1570,16 @@ dependencies = [
"petgraph",
"pin-project-lite",
"pnet",
"postcard",
"prost",
"prost-build",
"prost-types",
"quinn",
"rand 0.8.5",
"rcgen",
"regex",
"reqwest 0.11.27",
"ring 0.17.8",
"rpc_build",
"rstest",
"rust-i18n",
"rustls",
@ -1657,7 +1591,6 @@ dependencies = [
"sys-locale",
"tabled",
"tachyonix",
"tarpc",
"thiserror",
"time",
"timedmap",
@ -1668,7 +1601,6 @@ dependencies = [
"tokio-util",
"tokio-websockets",
"toml 0.8.19",
"tonic",
"tonic-build",
"tracing",
"tracing-appender",
@ -1736,12 +1668,6 @@ version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7"
[[package]]
name = "embedded-io"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
[[package]]
name = "encoding"
version = "0.2.33"
@ -2532,25 +2458,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "h2"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http 1.1.0",
"indexmap 2.4.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "2.4.1"
@ -2561,15 +2468,6 @@ dependencies = [
"crunchy",
]
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "hash32"
version = "0.3.1"
@ -2591,27 +2489,13 @@ version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32 0.2.1",
"rustc_version",
"serde",
"spin 0.9.8",
"stable_deref_trait",
]
[[package]]
name = "heapless"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad"
dependencies = [
"hash32 0.3.1",
"hash32",
"stable_deref_trait",
]
@ -2754,12 +2638,6 @@ dependencies = [
"libm",
]
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "0.14.30"
@ -2770,7 +2648,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
@ -2793,11 +2671,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.5",
"http 1.1.0",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa 1.0.11",
"pin-project-lite",
"smallvec",
@ -2805,19 +2681,6 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-timeout"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.5.0"
@ -3380,12 +3243,6 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "md5"
version = "0.7.0"
@ -3954,25 +3811,6 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "opentelemetry"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6105e89802af13fdf48c49d7646d3b533a70e536d818aae7e78ba0433d01acb8"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures-channel",
"futures-executor",
"futures-util",
"js-sys",
"lazy_static",
"percent-encoding",
"pin-project",
"rand 0.8.5",
"thiserror",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@ -4482,18 +4320,6 @@ dependencies = [
"universal-hash",
]
[[package]]
name = "postcard"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a55c51ee6c0db07e68448e336cf8ea4131a620edefebf9893e759b2d793420f8"
dependencies = [
"cobs",
"embedded-io",
"heapless 0.7.17",
"serde",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
@ -4606,9 +4432,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.13.1"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc"
checksum = "3b2ecbe40f08db5c006b5764a2645f7f3f141ce756412ac9e1dd6087e6d32995"
dependencies = [
"bytes",
"prost-derive",
@ -4616,9 +4442,9 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.13.1"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1"
checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302"
dependencies = [
"bytes",
"heck 0.5.0",
@ -4637,9 +4463,9 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.13.1"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca"
checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac"
dependencies = [
"anyhow",
"itertools 0.13.0",
@ -4650,9 +4476,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.13.1"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2"
checksum = "60caa6738c7369b940c3d49246a8d1749323674c65cb13010134f5c9bad5b519"
dependencies = [
"prost",
]
@ -4939,7 +4765,7 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.30",
@ -5035,6 +4861,14 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rpc_build"
version = "0.1.0"
dependencies = [
"heck 0.5.0",
"prost-build",
]
[[package]]
name = "rstest"
version = "0.18.2"
@ -5667,7 +5501,7 @@ dependencies = [
"byteorder",
"cfg-if",
"defmt",
"heapless 0.8.0",
"heapless",
"managed",
]
@ -6000,40 +5834,6 @@ version = "0.12.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tarpc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f07cb5fb67b0a90ea954b5ffd2fac9944ffef5937c801b987d3f8913f0c37348"
dependencies = [
"anyhow",
"fnv",
"futures",
"humantime",
"opentelemetry",
"pin-project",
"rand 0.8.5",
"serde",
"static_assertions",
"tarpc-plugins",
"thiserror",
"tokio",
"tokio-util",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "tarpc-plugins"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "tauri"
version = "2.0.0-rc.2"
@ -6590,7 +6390,6 @@ dependencies = [
"futures-core",
"futures-sink",
"pin-project-lite",
"slab",
"tokio",
]
@ -6696,36 +6495,6 @@ dependencies = [
"winnow 0.6.18",
]
[[package]]
name = "tonic"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38659f4a91aba8598d27821589f5db7dddd94601e7a01b1e485a50e5484c7401"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.22.1",
"bytes",
"h2 0.4.5",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost",
"socket2",
"tokio",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.12.1"
@ -6747,16 +6516,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
"rand 0.8.5",
"slab",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@ -6827,19 +6591,6 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.17.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbbe89715c1dbbb790059e2565353978564924ee85017b5fff365c872ff6721f"
dependencies = [
"once_cell",
"opentelemetry",
"tracing",
"tracing-core",
"tracing-subscriber",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"

View File

@ -10,4 +10,3 @@ panic = "unwind"
panic = "abort"
lto = true
codegen-units = 1
strip = true

View File

@ -49,7 +49,7 @@ futures = { version = "0.3", features = ["bilock", "unstable"] }
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
tokio-util = { version = "0.7.9", features = ["codec", "net"] }
tokio-util = { version = "0.7.9", features = ["codec", "net", "io"] }
async-stream = "0.3.5"
async-trait = "0.1.74"
@ -101,14 +101,10 @@ uuid = { version = "1.5.0", features = [
crossbeam-queue = "0.3"
once_cell = "1.18.0"
# for packet
postcard = { "version" = "1.0.8", features = ["alloc"] }
# for rpc
tonic = "0.12"
prost = "0.13"
prost-types = "0.13"
anyhow = "1.0"
tarpc = { version = "0.32", features = ["tokio1", "serde1"] }
url = { version = "2.5", features = ["serde"] }
percent-encoding = "2.3.1"
@ -194,6 +190,8 @@ winreg = "0.52"
tonic-build = "0.12"
globwalk = "0.8.1"
regex = "1"
prost-build = "0.13.2"
rpc_build = { path = "src/proto/rpc_build" }
[target.'cfg(windows)'.build-dependencies]
reqwest = { version = "0.11", features = ["blocking"] }

View File

@ -129,14 +129,31 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(target_os = "windows")]
WindowsBuild::check_for_win();
tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute("cli.DirectConnectedPeerInfo", "#[derive(Hash)]")
.type_attribute("cli.PeerInfoForGlobalMap", "#[derive(Hash)]")
prost_build::Config::new()
.type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute(".cli", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute(
"peer_rpc.GetIpListResponse",
"#[derive(serde::Serialize, serde::Deserialize)]",
)
.type_attribute("peer_rpc.DirectConnectedPeerInfo", "#[derive(Hash)]")
.type_attribute("peer_rpc.PeerInfoForGlobalMap", "#[derive(Hash)]")
.type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]")
.service_generator(Box::new(rpc_build::ServiceGenerator::new()))
.btree_map(&["."])
.compile(&["proto/cli.proto"], &["proto/"])
.compile_protos(
&[
"src/proto/peer_rpc.proto",
"src/proto/common.proto",
"src/proto/error.proto",
"src/proto/tests.proto",
"src/proto/cli.proto",
],
&["src/proto/"],
)
.unwrap();
// tonic_build::compile_protos("proto/cli.proto")?;
check_locale();
Ok(())
}

View File

@ -31,8 +31,6 @@ pub enum Error {
// RpcListenError(String),
#[error("Rpc connect error: {0}")]
RpcConnectError(String),
#[error("Rpc error: {0}")]
RpcClientError(#[from] tarpc::client::RpcError),
#[error("Timeout error: {0}")]
Timeout(#[from] tokio::time::error::Elapsed),
#[error("url in blacklist")]

View File

@ -4,7 +4,7 @@ use std::{
sync::{Arc, Mutex},
};
use crate::rpc::PeerConnInfo;
use crate::proto::cli::PeerConnInfo;
use crossbeam::atomic::AtomicCell;
use super::{
@ -179,6 +179,10 @@ impl GlobalCtx {
self.config.get_network_identity()
}
/// Returns the name of the network this context belongs to, taken from
/// the configured network identity.
pub fn get_network_name(&self) -> String {
    let identity = self.get_network_identity();
    identity.network_name
}
pub fn get_ip_collector(&self) -> Arc<IPCollector> {
self.ip_collector.clone()
}
@ -191,7 +195,6 @@ impl GlobalCtx {
self.stun_info_collection.as_ref()
}
#[cfg(test)]
pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
// force replace the stun_info_collection without mut and drop the old one
let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;

View File

@ -1,12 +1,13 @@
use std::{net::IpAddr, ops::Deref, sync::Arc};
use crate::rpc::peer::GetIpListResponse;
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use crate::proto::peer_rpc::GetIpListResponse;
use super::{netns::NetNS, stun::StunInfoCollectorTrait};
pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60;
@ -163,7 +164,7 @@ pub struct IPCollector {
impl IPCollector {
pub fn new<T: StunInfoCollectorTrait + 'static>(net_ns: NetNS, stun_info_collector: T) -> Self {
Self {
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::default())),
collect_ip_task: Mutex::new(JoinSet::new()),
net_ns,
stun_info_collector: Arc::new(Box::new(stun_info_collector)),
@ -195,14 +196,18 @@ impl IPCollector {
let Ok(ip_addr) = ip.parse::<IpAddr>() else {
continue;
};
if ip_addr.is_ipv4() {
cached_ip_list.write().await.public_ipv4 = ip.clone();
} else {
cached_ip_list.write().await.public_ipv6 = ip.clone();
match ip_addr {
IpAddr::V4(v) => {
cached_ip_list.write().await.public_ipv4 = Some(v.into())
}
IpAddr::V6(v) => {
cached_ip_list.write().await.public_ipv6 = Some(v.into())
}
}
}
let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_empty() {
let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_none() {
CACHED_IP_LIST_TIMEOUT_SEC
} else {
3
@ -236,7 +241,7 @@ impl IPCollector {
#[tracing::instrument(skip(net_ns))]
async fn do_collect_local_ip_addrs(net_ns: NetNS) -> GetIpListResponse {
let mut ret = crate::rpc::peer::GetIpListResponse::new();
let mut ret = GetIpListResponse::default();
let ifaces = Self::collect_interfaces(net_ns.clone()).await;
let _g = net_ns.guard();
@ -246,25 +251,28 @@ impl IPCollector {
if ip.is_loopback() || ip.is_multicast() {
continue;
}
if ip.is_ipv4() {
ret.interface_ipv4s.push(ip.to_string());
} else if ip.is_ipv6() {
ret.interface_ipv6s.push(ip.to_string());
match ip {
std::net::IpAddr::V4(v4) => {
ret.interface_ipv4s.push(v4.into());
}
std::net::IpAddr::V6(v6) => {
ret.interface_ipv6s.push(v6.into());
}
}
}
}
if let Ok(v4_addr) = local_ipv4().await {
tracing::trace!("got local ipv4: {}", v4_addr);
if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
ret.interface_ipv4s.push(v4_addr.to_string());
if !ret.interface_ipv4s.contains(&v4_addr.into()) {
ret.interface_ipv4s.push(v4_addr.into());
}
}
if let Ok(v6_addr) = local_ipv6().await {
tracing::trace!("got local ipv6: {}", v6_addr);
if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
ret.interface_ipv6s.push(v6_addr.to_string());
if !ret.interface_ipv6s.contains(&v6_addr.into()) {
ret.interface_ipv6s.push(v6_addr.into());
}
}

View File

@ -1,9 +1,10 @@
use std::collections::BTreeSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::rpc::{NatType, StunInfo};
use crate::proto::common::{NatType, StunInfo};
use anyhow::Context;
use chrono::Local;
use crossbeam::atomic::AtomicCell;
@ -161,7 +162,7 @@ impl StunClient {
continue;
};
tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
tracing::trace!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
if msg.class() != MessageClass::SuccessResponse
|| msg.method() != BINDING
@ -216,7 +217,7 @@ impl StunClient {
changed_addr
}
#[tracing::instrument(ret, err, level = Level::DEBUG)]
#[tracing::instrument(ret, level = Level::TRACE)]
pub async fn bind_request(
self,
change_ip: bool,
@ -243,7 +244,7 @@ impl StunClient {
.encode_into_bytes(message.clone())
.with_context(|| "encode stun message")?;
tids.push(tid as u128);
tracing::debug!(?message, ?msg, tid, "send stun request");
tracing::trace!(?message, ?msg, tid, "send stun request");
self.socket
.send_to(msg.as_slice().into(), &stun_host)
.await?;
@ -276,7 +277,7 @@ impl StunClient {
latency_us: now.elapsed().as_micros() as u32,
};
tracing::debug!(
tracing::trace!(
?stun_host,
?recv_addr,
?changed_socket_addr,
@ -303,14 +304,14 @@ impl StunClientBuilder {
task_set.spawn(
async move {
let mut buf = [0; 1620];
tracing::info!("start stun packet listener");
tracing::trace!("start stun packet listener");
loop {
let Ok((len, addr)) = udp_clone.recv_from(&mut buf).await else {
tracing::error!("udp recv_from error");
break;
};
let data = buf[..len].to_vec();
tracing::debug!(?addr, ?data, "recv udp stun packet");
tracing::trace!(?addr, ?data, "recv udp stun packet");
let _ = stun_packet_sender_clone.send(StunPacket { data, addr });
}
}
@ -552,12 +553,15 @@ pub struct StunInfoCollector {
udp_nat_test_result: Arc<RwLock<Option<UdpNatTypeDetectResult>>>,
nat_test_result_time: Arc<AtomicCell<chrono::DateTime<Local>>>,
redetect_notify: Arc<tokio::sync::Notify>,
tasks: JoinSet<()>,
tasks: std::sync::Mutex<JoinSet<()>>,
started: AtomicBool,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
self.start_stun_routine();
let Some(result) = self.udp_nat_test_result.read().unwrap().clone() else {
return Default::default();
};
@ -572,6 +576,8 @@ impl StunInfoCollectorTrait for StunInfoCollector {
}
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
self.start_stun_routine();
let stun_servers = self
.udp_nat_test_result
.read()
@ -605,17 +611,14 @@ impl StunInfoCollectorTrait for StunInfoCollector {
impl StunInfoCollector {
pub fn new(stun_servers: Vec<String>) -> Self {
let mut ret = Self {
Self {
stun_servers: Arc::new(RwLock::new(stun_servers)),
udp_nat_test_result: Arc::new(RwLock::new(None)),
nat_test_result_time: Arc::new(AtomicCell::new(Local::now())),
redetect_notify: Arc::new(tokio::sync::Notify::new()),
tasks: JoinSet::new(),
};
ret.start_stun_routine();
ret
tasks: std::sync::Mutex::new(JoinSet::new()),
started: AtomicBool::new(false),
}
}
pub fn new_with_default_servers() -> Self {
@ -648,12 +651,18 @@ impl StunInfoCollector {
.collect()
}
fn start_stun_routine(&mut self) {
fn start_stun_routine(&self) {
if self.started.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
self.started
.store(true, std::sync::atomic::Ordering::Relaxed);
let stun_servers = self.stun_servers.clone();
let udp_nat_test_result = self.udp_nat_test_result.clone();
let udp_test_time = self.nat_test_result_time.clone();
let redetect_notify = self.redetect_notify.clone();
self.tasks.spawn(async move {
self.tasks.lock().unwrap().spawn(async move {
loop {
let servers = stun_servers.read().unwrap().clone();
// use first three and random choose one from the rest
@ -712,6 +721,31 @@ impl StunInfoCollector {
}
}
pub struct MockStunInfoCollector {
pub udp_nat_type: NatType,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for MockStunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
StunInfo {
udp_nat_type: self.udp_nat_type as i32,
tcp_nat_type: NatType::Unknown as i32,
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
min_port: 100,
max_port: 200,
..Default::default()
}
}
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<std::net::SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -5,9 +5,17 @@ use std::{net::SocketAddr, sync::Arc};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
proto::{
peer_rpc::{
DirectConnectorRpc, DirectConnectorRpcClientFactory, DirectConnectorRpcServer,
GetIpListRequest, GetIpListResponse,
},
rpc_types::{self, controller::BaseController},
},
};
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo};
use crate::proto::cli::PeerConnInfo;
use anyhow::Context;
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
use url::Host;
@ -17,11 +25,6 @@ use super::create_connector_by_url;
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
#[tarpc::service]
pub trait DirectConnectorRpc {
async fn get_ip_list() -> GetIpListResponse;
}
#[async_trait::async_trait]
pub trait PeerManagerForDirectConnector {
async fn list_peers(&self) -> Vec<PeerId>;
@ -57,12 +60,23 @@ struct DirectConnectorManagerRpcServer {
global_ctx: ArcGlobalCtx,
}
#[tarpc::server]
#[async_trait::async_trait]
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
type Controller = BaseController;
async fn get_ip_list(
&self,
_: BaseController,
_: GetIpListRequest,
) -> rpc_types::error::Result<GetIpListResponse> {
let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
ret.listeners = self.global_ctx.get_running_listeners();
ret
ret.listeners = self
.global_ctx
.get_running_listeners()
.into_iter()
.map(Into::into)
.collect();
Ok(ret)
}
}
@ -130,10 +144,17 @@ impl DirectConnectorManager {
}
pub fn run_as_server(&mut self) {
self.data.peer_manager.get_peer_rpc_mgr().run_service(
DIRECT_CONNECTOR_SERVICE_ID,
DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
);
self.data
.peer_manager
.get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
DirectConnectorRpcServer::new(DirectConnectorManagerRpcServer::new(
self.global_ctx.clone(),
)),
&self.data.global_ctx.get_network_name(),
);
}
pub fn run_as_client(&mut self) {
@ -238,7 +259,8 @@ impl DirectConnectorManager {
let enable_ipv6 = data.global_ctx.get_flags().enable_ipv6;
let available_listeners = ip_list
.listeners
.iter()
.into_iter()
.map(Into::<url::Url>::into)
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
.filter(|l| l.port().is_some() && l.host().is_some())
.filter(|l| {
@ -268,7 +290,7 @@ impl DirectConnectorManager {
Some(SocketAddr::V4(_)) => {
ip_list.interface_ipv4s.iter().for_each(|ip| {
let mut addr = (*listener).clone();
if addr.set_host(Some(ip.as_str())).is_ok() {
if addr.set_host(Some(ip.to_string().as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
@ -277,19 +299,27 @@ impl DirectConnectorManager {
}
});
let mut addr = (*listener).clone();
if addr.set_host(Some(ip_list.public_ipv4.as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
if let Some(public_ipv4) = ip_list.public_ipv4 {
let mut addr = (*listener).clone();
if addr
.set_host(Some(public_ipv4.to_string().as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
}
}
Some(SocketAddr::V6(_)) => {
ip_list.interface_ipv6s.iter().for_each(|ip| {
let mut addr = (*listener).clone();
if addr.set_host(Some(format!("[{}]", ip).as_str())).is_ok() {
if addr
.set_host(Some(format!("[{}]", ip.to_string()).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
@ -298,16 +328,18 @@ impl DirectConnectorManager {
}
});
let mut addr = (*listener).clone();
if addr
.set_host(Some(format!("[{}]", ip_list.public_ipv6).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
if let Some(public_ipv6) = ip_list.public_ipv6 {
let mut addr = (*listener).clone();
if addr
.set_host(Some(format!("[{}]", public_ipv6.to_string()).as_str()))
.is_ok()
{
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr.to_string(),
));
}
}
}
p => {
@ -351,16 +383,21 @@ impl DirectConnectorManager {
tracing::trace!("try direct connect to peer: {}", dst_peer_id);
let ip_list = peer_manager
let rpc_stub = peer_manager
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
let client =
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
let ip_list = client.get_ip_list(tarpc::context::current()).await;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
ip_list
})
.await?;
.rpc_client()
.scoped_client::<DirectConnectorRpcClientFactory<BaseController>>(
peer_manager.my_peer_id(),
dst_peer_id,
data.global_ctx.get_network_name(),
);
let ip_list = rpc_stub
.get_ip_list(BaseController {}, GetIpListRequest {})
.await
.with_context(|| format!("get ip list from peer {}", dst_peer_id))?;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
}
@ -380,7 +417,7 @@ mod tests {
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
wait_route_appear_with_cost,
},
rpc::peer::GetIpListResponse,
proto::peer_rpc::GetIpListResponse,
};
#[rstest::rstest]
@ -436,12 +473,14 @@ mod tests {
p_a.get_global_ctx(),
p_a.clone(),
));
let mut ip_list = GetIpListResponse::new();
let mut ip_list = GetIpListResponse::default();
ip_list
.listeners
.push("tcp://127.0.0.1:10222".parse().unwrap());
ip_list.interface_ipv4s.push("127.0.0.1".to_string());
ip_list
.interface_ipv4s
.push("127.0.0.1".parse::<std::net::Ipv4Addr>().unwrap().into());
DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
.await

View File

@ -11,7 +11,12 @@ use tokio::{
use crate::{
common::PeerId,
peers::peer_conn::PeerConnId,
rpc as easytier_rpc,
proto::{
cli::{
ConnectorManageAction, ListConnectorResponse, ManageConnectorResponse, PeerConnInfo,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{IpVersion, TunnelConnector},
};
@ -23,9 +28,9 @@ use crate::{
},
connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager,
rpc::{
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
ListConnectorRequest, ManageConnectorRequest,
proto::cli::{
Connector, ConnectorManageRpc, ConnectorStatus, ListConnectorRequest,
ManageConnectorRequest,
},
use_global_var,
};
@ -105,12 +110,18 @@ impl ManualConnectorManager {
Ok(())
}
pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
pub async fn remove_connector(&self, url: url::Url) -> Result<(), Error> {
tracing::info!("remove_connector: {}", url);
if !self.list_connectors().await.iter().any(|x| x.url == url) {
let url = url.into();
if !self
.list_connectors()
.await
.iter()
.any(|x| x.url.as_ref() == Some(&url))
{
return Err(Error::NotFound);
}
self.data.removed_conn_urls.insert(url.into());
self.data.removed_conn_urls.insert(url.to_string());
Ok(())
}
@ -137,7 +148,7 @@ impl ManualConnectorManager {
ret.insert(
0,
Connector {
url: conn_url,
url: Some(conn_url.parse().unwrap()),
status: status.into(),
},
);
@ -154,7 +165,7 @@ impl ManualConnectorManager {
ret.insert(
0,
Connector {
url: conn_url,
url: Some(conn_url.parse().unwrap()),
status: ConnectorStatus::Connecting.into(),
},
);
@ -213,14 +224,14 @@ impl ManualConnectorManager {
}
async fn handle_event(event: &GlobalCtxEvent, data: &ConnectorManagerData) {
let need_add_alive = |conn_info: &easytier_rpc::PeerConnInfo| conn_info.is_client;
let need_add_alive = |conn_info: &PeerConnInfo| conn_info.is_client;
match event {
GlobalCtxEvent::PeerConnAdded(conn_info) => {
if !need_add_alive(conn_info) {
return;
}
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
data.alive_conn_urls.insert(addr);
data.alive_conn_urls.insert(addr.unwrap().to_string());
tracing::warn!("peer conn added: {:?}", conn_info);
}
@ -229,7 +240,7 @@ impl ManualConnectorManager {
return;
}
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
data.alive_conn_urls.remove(&addr);
data.alive_conn_urls.remove(&addr.unwrap().to_string());
tracing::warn!("peer conn removed: {:?}", conn_info);
}
@ -303,7 +314,7 @@ impl ManualConnectorManager {
tracing::info!("reconnect get tunnel succ: {:?}", tunnel);
assert_eq!(
dead_url,
tunnel.info().unwrap().remote_addr,
tunnel.info().unwrap().remote_addr.unwrap().to_string(),
"info: {:?}",
tunnel.info()
);
@ -385,45 +396,43 @@ impl ManualConnectorManager {
}
}
#[derive(Clone)]
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
#[tonic::async_trait]
#[async_trait::async_trait]
impl ConnectorManageRpc for ConnectorManagerRpcService {
type Controller = BaseController;
async fn list_connector(
&self,
_request: tonic::Request<ListConnectorRequest>,
) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
let mut ret = easytier_rpc::ListConnectorResponse::default();
_: BaseController,
_request: ListConnectorRequest,
) -> Result<ListConnectorResponse, rpc_types::error::Error> {
let mut ret = ListConnectorResponse::default();
let connectors = self.0.list_connectors().await;
ret.connectors = connectors;
Ok(tonic::Response::new(ret))
Ok(ret)
}
async fn manage_connector(
&self,
request: tonic::Request<ManageConnectorRequest>,
) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
let req = request.into_inner();
let url = url::Url::parse(&req.url)
.map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
self.0.remove_connector(url.path()).await.map_err(|e| {
tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
})?;
return Ok(tonic::Response::new(
easytier_rpc::ManageConnectorResponse::default(),
));
_: BaseController,
req: ManageConnectorRequest,
) -> Result<ManageConnectorResponse, rpc_types::error::Error> {
let url: url::Url = req.url.ok_or(anyhow::anyhow!("url is empty"))?.into();
if req.action == ConnectorManageAction::Remove as i32 {
self.0
.remove_connector(url.clone())
.await
.with_context(|| format!("remove connector failed: {:?}", url))?;
return Ok(ManageConnectorResponse::default());
} else {
self.0
.add_connector_by_url(url.as_str())
.await
.map_err(|e| {
tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
})?;
.with_context(|| format!("add connector failed: {:?}", url))?;
}
Ok(tonic::Response::new(
easytier_rpc::ManageConnectorResponse::default(),
))
Ok(ManageConnectorResponse::default())
}
}

View File

@ -32,14 +32,14 @@ async fn set_bind_addr_for_peer_connector(
if is_ipv4 {
let mut bind_addrs = vec![];
for ipv4 in ips.interface_ipv4s {
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
let socket_addr = SocketAddrV4::new(ipv4.into(), 0).into();
bind_addrs.push(socket_addr);
}
connector.set_bind_addrs(bind_addrs);
} else {
let mut bind_addrs = vec![];
for ipv6 in ips.interface_ipv6s {
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
let socket_addr = SocketAddrV6::new(ipv6.into(), 0, 0, 0).into();
bind_addrs.push(socket_addr);
}
connector.set_bind_addrs(bind_addrs);

View File

@ -5,6 +5,7 @@ use std::{
Arc,
},
time::Duration,
u16,
};
use anyhow::Context;
@ -21,12 +22,20 @@ use zerocopy::FromBytes;
use crate::{
common::{
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId,
},
defer,
peers::peer_manager::PeerManager,
rpc::NatType,
proto::{
common::NatType,
peer_rpc::{
TryPunchHoleRequest, TryPunchHoleResponse, TryPunchSymmetricRequest,
TryPunchSymmetricResponse, UdpHolePunchRpc, UdpHolePunchRpcClientFactory,
UdpHolePunchRpcServer,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{
common::setup_sokcet2,
packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE},
@ -186,21 +195,6 @@ impl std::fmt::Debug for UdpSocketArray {
}
}
#[tarpc::service]
pub trait UdpHolePunchService {
async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
async fn try_punch_symmetric(
listener_addr: SocketAddr,
port: u16,
public_ips: Vec<Ipv4Addr>,
min_port: u16,
max_port: u16,
transaction_id: u32,
round: u32,
last_port_index: usize,
) -> Option<usize>;
}
#[derive(Debug)]
struct UdpHolePunchListener {
socket: Arc<UdpSocket>,
@ -324,23 +318,34 @@ impl UdpHolePunchConnectorData {
}
#[derive(Clone)]
struct UdpHolePunchRpcServer {
struct UdpHolePunchRpcService {
data: Arc<UdpHolePunchConnectorData>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
#[tarpc::server]
impl UdpHolePunchService for UdpHolePunchRpcServer {
#[async_trait::async_trait]
impl UdpHolePunchRpc for UdpHolePunchRpcService {
type Controller = BaseController;
#[tracing::instrument(skip(self))]
async fn try_punch_hole(
self,
_: tarpc::context::Context,
local_mapped_addr: SocketAddr,
) -> Option<SocketAddr> {
&self,
_: BaseController,
request: TryPunchHoleRequest,
) -> Result<TryPunchHoleResponse, rpc_types::error::Error> {
let local_mapped_addr = request.local_mapped_addr.ok_or(anyhow::anyhow!(
"try_punch_hole request missing local_mapped_addr"
))?;
let local_mapped_addr = std::net::SocketAddr::from(local_mapped_addr);
// local mapped addr will be unspecified if peer is symmetric
let peer_is_symmetric = local_mapped_addr.ip().is_unspecified();
let (socket, mapped_addr) = self.select_listener(peer_is_symmetric).await?;
let (socket, mapped_addr) =
self.select_listener(peer_is_symmetric)
.await
.ok_or(anyhow::anyhow!(
"failed to select listener for hole punching"
))?;
tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
if !peer_is_symmetric {
@ -380,32 +385,48 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
}
}
Some(mapped_addr)
Ok(TryPunchHoleResponse {
remote_mapped_addr: Some(mapped_addr.into()),
})
}
#[instrument(skip(self))]
async fn try_punch_symmetric(
self,
_: tarpc::context::Context,
listener_addr: SocketAddr,
port: u16,
public_ips: Vec<Ipv4Addr>,
mut min_port: u16,
mut max_port: u16,
transaction_id: u32,
round: u32,
last_port_index: usize,
) -> Option<usize> {
&self,
_: BaseController,
request: TryPunchSymmetricRequest,
) -> Result<TryPunchSymmetricResponse, rpc_types::error::Error> {
let listener_addr = request.listener_addr.ok_or(anyhow::anyhow!(
"try_punch_symmetric request missing listener_addr"
))?;
let listener_addr = std::net::SocketAddr::from(listener_addr);
let port = request.port as u16;
let public_ips = request
.public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.collect::<Vec<_>>();
let mut min_port = request.min_port as u16;
let mut max_port = request.max_port as u16;
let transaction_id = request.transaction_id;
let round = request.round;
let last_port_index = request.last_port_index as usize;
tracing::info!("try_punch_symmetric start");
let punch_predictablely = self.data.punch_predicablely.load(Ordering::Relaxed);
let punch_randomly = self.data.punch_randomly.load(Ordering::Relaxed);
let total_port_count = self.data.shuffled_port_vec.len();
let listener = self.find_listener(&listener_addr).await?;
let listener = self
.find_listener(&listener_addr)
.await
.ok_or(anyhow::anyhow!(
"try_punch_symmetric failed to find listener"
))?;
let ip_count = public_ips.len();
if ip_count == 0 {
tracing::warn!("try_punch_symmetric got zero len public ip");
return None;
return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
}
min_port = std::cmp::max(1, min_port);
@ -447,7 +468,7 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
&ports,
)
.await
.ok()?;
.with_context(|| "failed to send symmetric hole punch packet predict")?;
}
if punch_randomly {
@ -461,20 +482,22 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
&self.data.shuffled_port_vec[start..end],
)
.await
.ok()?;
.with_context(|| "failed to send symmetric hole punch packet randomly")?;
return if end >= self.data.shuffled_port_vec.len() {
Some(1)
Ok(TryPunchSymmetricResponse { last_port_index: 1 })
} else {
Some(end)
Ok(TryPunchSymmetricResponse {
last_port_index: end as u32,
})
};
}
return Some(1);
return Ok(TryPunchSymmetricResponse { last_port_index: 1 });
}
}
impl UdpHolePunchRpcServer {
impl UdpHolePunchRpcService {
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
@ -593,10 +616,15 @@ impl UdpHolePunchConnector {
}
pub async fn run_as_server(&mut self) -> Result<(), Error> {
self.data.peer_mgr.get_peer_rpc_mgr().run_service(
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
UdpHolePunchRpcServer::new(self.data.clone()).serve(),
);
self.data
.peer_mgr
.get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
UdpHolePunchRpcServer::new(UdpHolePunchRpcService::new(self.data.clone())),
&self.data.global_ctx.get_network_name(),
);
Ok(())
}
@ -736,26 +764,26 @@ impl UdpHolePunchConnector {
.with_context(|| "failed to get udp port mapping")?;
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
let Some(remote_mapped_addr) = data
let rpc_stub = data
.peer_mgr
.get_peer_rpc_mgr()
.do_client_rpc_scoped(
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
.rpc_client()
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
data.peer_mgr.my_peer_id(),
dst_peer_id,
|c| async {
let client =
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
let remote_mapped_addr = client
.try_punch_hole(tarpc::context::current(), local_mapped_addr)
.await;
tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
remote_mapped_addr
data.global_ctx.get_network_name(),
);
let remote_mapped_addr = rpc_stub
.try_punch_hole(
BaseController {},
TryPunchHoleRequest {
local_mapped_addr: Some(local_mapped_addr.into()),
},
)
.await?
else {
return Err(anyhow::anyhow!("failed to get remote mapped addr"));
};
.remote_mapped_addr
.ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?;
// server: will send some punching resps, total 10 packets.
// client: use the socket to create UdpTunnel with UdpTunnelConnector
@ -769,9 +797,11 @@ impl UdpHolePunchConnector {
setup_sokcet2(&socket2_socket, &local_socket_addr)?;
let socket = Arc::new(UdpSocket::from_std(socket2_socket.into())?);
Ok(Self::try_connect_with_socket(socket, remote_mapped_addr)
.await
.with_context(|| "UdpTunnelConnector failed to connect remote")?)
Ok(
Self::try_connect_with_socket(socket, remote_mapped_addr.into())
.await
.with_context(|| "UdpTunnelConnector failed to connect remote")?,
)
}
#[tracing::instrument(err(level = Level::ERROR))]
@ -783,30 +813,28 @@ impl UdpHolePunchConnector {
return Err(anyhow::anyhow!("udp array not started"));
};
let Some(remote_mapped_addr) = data
let rpc_stub = data
.peer_mgr
.get_peer_rpc_mgr()
.do_client_rpc_scoped(
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
.rpc_client()
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
data.peer_mgr.my_peer_id(),
dst_peer_id,
|c| async {
let client =
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
let remote_mapped_addr = client
.try_punch_hole(tarpc::context::current(), "0.0.0.0:0".parse().unwrap())
.await;
tracing::debug!(
?remote_mapped_addr,
?dst_peer_id,
"hole punching symmetric got remote mapped addr"
);
remote_mapped_addr
data.global_ctx.get_network_name(),
);
let local_mapped_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let remote_mapped_addr = rpc_stub
.try_punch_hole(
BaseController {},
TryPunchHoleRequest {
local_mapped_addr: Some(local_mapped_addr.into()),
},
)
.await?
else {
return Err(anyhow::anyhow!("failed to get remote mapped addr"));
};
.remote_mapped_addr
.ok_or(anyhow::anyhow!("failed to get remote mapped addr"))?
.into();
// try direct connect first
if data.try_direct_connect.load(Ordering::Relaxed) {
@ -852,38 +880,26 @@ impl UdpHolePunchConnector {
let mut last_port_idx = rand::thread_rng().gen_range(0..data.shuffled_port_vec.len());
for round in 0..5 {
let ret = data
.peer_mgr
.get_peer_rpc_mgr()
.do_client_rpc_scoped(
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
dst_peer_id,
|c| async {
let client =
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let last_port_idx = client
.try_punch_symmetric(
tarpc::context::current(),
remote_mapped_addr,
port,
public_ips.clone(),
stun_info.min_port as u16,
stun_info.max_port as u16,
tid,
round,
last_port_idx,
)
.await;
tracing::info!(?last_port_idx, ?dst_peer_id, "punch symmetric return");
last_port_idx
let ret = rpc_stub
.try_punch_symmetric(
BaseController {},
TryPunchSymmetricRequest {
listener_addr: Some(remote_mapped_addr.into()),
port: port as u32,
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
min_port: stun_info.min_port as u32,
max_port: stun_info.max_port as u32,
transaction_id: tid,
round,
last_port_index: last_port_idx as u32,
},
)
.await;
tracing::info!(?ret, "punch symmetric return");
let next_last_port_idx = match ret {
Ok(Some(idx)) => idx,
err => {
Ok(s) => s.last_port_index as usize,
Err(err) => {
tracing::error!(?err, "failed to get remote mapped addr");
rand::thread_rng().gen_range(0..data.shuffled_port_vec.len())
}
@ -1027,11 +1043,11 @@ pub mod tests {
use tokio::net::UdpSocket;
use crate::rpc::{NatType, StunInfo};
use crate::common::stun::MockStunInfoCollector;
use crate::proto::common::NatType;
use crate::tunnel::common::tests::wait_for_condition;
use crate::{
common::{error::Error, stun::StunInfoCollectorTrait},
connector::udp_hole_punch::UdpHolePunchConnector,
peers::{
peer_manager::PeerManager,
@ -1042,31 +1058,6 @@ pub mod tests {
},
};
struct MockStunInfoCollector {
udp_nat_type: NatType,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for MockStunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
StunInfo {
udp_nat_type: self.udp_nat_type as i32,
tcp_nat_type: NatType::Unknown as i32,
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
min_port: 100,
max_port: 200,
..Default::default()
}
}
async fn get_udp_port_mapping(&self, mut port: u16) -> Result<std::net::SocketAddr, Error> {
if port == 0 {
port = 40144;
}
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
}
}
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
peer_mgr

View File

@ -1,26 +1,29 @@
#![allow(dead_code)]
use std::{net::SocketAddr, time::Duration, vec};
use std::{net::SocketAddr, sync::Mutex, time::Duration, vec};
use anyhow::{Context, Ok};
use clap::{command, Args, Parser, Subcommand};
use common::stun::StunInfoCollectorTrait;
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient;
use proto::{
common::NatType,
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
rpc_impl::standalone::StandAloneClient,
rpc_types::controller::BaseController,
};
use tokio::time::timeout;
use tunnel::tcp::TcpTunnelConnector;
use utils::{list_peer_route_pair, PeerRoutePair};
mod arch;
mod common;
mod rpc;
mod proto;
mod tunnel;
mod utils;
use crate::{
common::stun::StunInfoCollector,
rpc::{
connector_manage_rpc_client::ConnectorManageRpcClient,
peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient,
*,
},
proto::cli::*,
utils::{cost_to_str, float_to_str},
};
use humansize::format_size;
@ -114,58 +117,76 @@ struct NodeArgs {
sub_command: Option<NodeSubCommand>,
}
#[derive(thiserror::Error, Debug)]
enum Error {
#[error("tonic transport error")]
TonicTransportError(#[from] tonic::transport::Error),
#[error("tonic rpc error")]
TonicRpcError(#[from] tonic::Status),
#[error("anyhow error")]
Anyhow(#[from] anyhow::Error),
}
type Error = anyhow::Error;
struct CommandHandler {
addr: String,
client: Mutex<RpcClient>,
verbose: bool,
}
type RpcClient = StandAloneClient<TcpTunnelConnector>;
impl CommandHandler {
async fn get_peer_manager_client(
&self,
) -> Result<PeerManageRpcClient<tonic::transport::Channel>, Error> {
Ok(PeerManageRpcClient::connect(self.addr.clone()).await?)
) -> Result<Box<dyn PeerManageRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer manager client")?)
}
async fn get_connector_manager_client(
&self,
) -> Result<ConnectorManageRpcClient<tonic::transport::Channel>, Error> {
Ok(ConnectorManageRpcClient::connect(self.addr.clone()).await?)
) -> Result<Box<dyn ConnectorManageRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<ConnectorManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get connector manager client")?)
}
async fn get_peer_center_client(
&self,
) -> Result<PeerCenterRpcClient<tonic::transport::Channel>, Error> {
Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?)
) -> Result<Box<dyn PeerCenterRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<PeerCenterRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer center client")?)
}
async fn get_vpn_portal_client(
&self,
) -> Result<VpnPortalRpcClient<tonic::transport::Channel>, Error> {
Ok(VpnPortalRpcClient::connect(self.addr.clone()).await?)
) -> Result<Box<dyn VpnPortalRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.unwrap()
.scoped_client::<VpnPortalRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get vpn portal client")?)
}
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let mut client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListPeerRequest::default());
let response = client.list_peer(request).await?;
Ok(response.into_inner())
let client = self.get_peer_manager_client().await?;
let request = ListPeerRequest::default();
let response = client.list_peer(BaseController {}, request).await?;
Ok(response)
}
async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
let mut client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListRouteRequest::default());
let response = client.list_route(request).await?;
Ok(response.into_inner())
let client = self.get_peer_manager_client().await?;
let request = ListRouteRequest::default();
let response = client.list_route(BaseController {}, request).await?;
Ok(response)
}
async fn list_peer_route_pair(&self) -> Result<Vec<PeerRoutePair>, Error> {
@ -251,11 +272,10 @@ impl CommandHandler {
return Ok(());
}
let mut client = self.get_peer_manager_client().await?;
let client = self.get_peer_manager_client().await?;
let node_info = client
.show_node_info(ShowNodeInfoRequest::default())
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await?
.into_inner()
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;
items.push(node_info.into());
@ -273,18 +293,20 @@ impl CommandHandler {
}
async fn handle_route_dump(&self) -> Result<(), Error> {
let mut client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(DumpRouteRequest::default());
let response = client.dump_route(request).await?;
println!("response: {}", response.into_inner().result);
let client = self.get_peer_manager_client().await?;
let request = DumpRouteRequest::default();
let response = client.dump_route(BaseController {}, request).await?;
println!("response: {}", response.result);
Ok(())
}
async fn handle_foreign_network_list(&self) -> Result<(), Error> {
let mut client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListForeignNetworkRequest::default());
let response = client.list_foreign_network(request).await?;
let network_map = response.into_inner();
let client = self.get_peer_manager_client().await?;
let request = ListForeignNetworkRequest::default();
let response = client
.list_foreign_network(BaseController {}, request)
.await?;
let network_map = response;
if self.verbose {
println!("{:#?}", network_map);
return Ok(());
@ -303,7 +325,7 @@ impl CommandHandler {
"remote_addr: {}, rx_bytes: {}, tx_bytes: {}, latency_us: {}",
conn.tunnel
.as_ref()
.map(|t| t.remote_addr.clone())
.map(|t| t.remote_addr.clone().unwrap_or_default())
.unwrap_or_default(),
conn.stats.as_ref().map(|s| s.rx_bytes).unwrap_or_default(),
conn.stats.as_ref().map(|s| s.tx_bytes).unwrap_or_default(),
@ -334,11 +356,10 @@ impl CommandHandler {
}
let mut items: Vec<RouteTableItem> = vec![];
let mut client = self.get_peer_manager_client().await?;
let client = self.get_peer_manager_client().await?;
let node_info = client
.show_node_info(ShowNodeInfoRequest::default())
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await?
.into_inner()
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;
@ -403,10 +424,10 @@ impl CommandHandler {
}
async fn handle_connector_list(&self) -> Result<(), Error> {
let mut client = self.get_connector_manager_client().await?;
let request = tonic::Request::new(ListConnectorRequest::default());
let response = client.list_connector(request).await?;
println!("response: {:#?}", response.into_inner());
let client = self.get_connector_manager_client().await?;
let request = ListConnectorRequest::default();
let response = client.list_connector(BaseController {}, request).await?;
println!("response: {:#?}", response);
Ok(())
}
}
@ -415,8 +436,13 @@ impl CommandHandler {
#[tracing::instrument]
async fn main() -> Result<(), Error> {
let cli = Cli::parse();
let client = RpcClient::new(TcpTunnelConnector::new(
format!("tcp://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port())
.parse()
.unwrap(),
));
let handler = CommandHandler {
addr: format!("http://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port()),
client: Mutex::new(client),
verbose: cli.verbose,
};
@ -476,11 +502,10 @@ async fn main() -> Result<(), Error> {
.unwrap();
}
SubCommand::PeerCenter => {
let mut peer_center_client = handler.get_peer_center_client().await?;
let peer_center_client = handler.get_peer_center_client().await?;
let resp = peer_center_client
.get_global_peer_map(GetGlobalPeerMapRequest::default())
.await?
.into_inner();
.get_global_peer_map(BaseController {}, GetGlobalPeerMapRequest::default())
.await?;
#[derive(tabled::Tabled)]
struct PeerCenterTableItem {
@ -510,11 +535,10 @@ async fn main() -> Result<(), Error> {
);
}
SubCommand::VpnPortal => {
let mut vpn_portal_client = handler.get_vpn_portal_client().await?;
let vpn_portal_client = handler.get_vpn_portal_client().await?;
let resp = vpn_portal_client
.get_vpn_portal_info(GetVpnPortalInfoRequest::default())
.get_vpn_portal_info(BaseController {}, GetVpnPortalInfoRequest::default())
.await?
.into_inner()
.vpn_portal_info
.unwrap_or_default();
println!("portal_name: {}", resp.vpn_type);
@ -529,11 +553,10 @@ async fn main() -> Result<(), Error> {
println!("connected_clients:\n{:#?}", resp.connected_clients);
}
SubCommand::Node(sub_cmd) => {
let mut client = handler.get_peer_manager_client().await?;
let client = handler.get_peer_manager_client().await?;
let node_info = client
.show_node_info(ShowNodeInfoRequest::default())
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.await?
.into_inner()
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;
match sub_cmd.sub_command {

View File

@ -21,7 +21,7 @@ mod gateway;
mod instance;
mod peer_center;
mod peers;
mod rpc;
mod proto;
mod tunnel;
mod utils;
mod vpn_portal;
@ -548,7 +548,7 @@ fn print_event(msg: String) {
);
}
fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String {
fn peer_conn_info_to_string(p: crate::proto::cli::PeerConnInfo) -> String {
format!(
"my_peer_id: {}, dst_peer_id: {}, tunnel_info: {:?}",
p.my_peer_id, p.peer_id, p.tunnel

View File

@ -8,8 +8,6 @@ use anyhow::Context;
use cidr::Ipv4Inet;
use tokio::{sync::Mutex, task::JoinSet};
use tonic::transport::server::TcpIncoming;
use tonic::transport::Server;
use crate::common::config::ConfigLoader;
use crate::common::error::Error;
@ -26,8 +24,13 @@ use crate::peers::peer_conn::PeerConnId;
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService;
use crate::peers::PacketRecvChanReceiver;
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::cli::VpnPortalRpc;
use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::proto::peer_rpc::PeerCenterRpcServer;
use crate::proto::rpc_impl::standalone::StandAloneServer;
use crate::proto::rpc_types;
use crate::proto::rpc_types::controller::BaseController;
use crate::tunnel::tcp::TcpTunnelListener;
use crate::vpn_portal::{self, VpnPortal};
use super::listeners::ListenerManager;
@ -104,8 +107,6 @@ pub struct Instance {
nic_ctx: ArcNicCtx,
tasks: JoinSet<()>,
peer_packet_receiver: Arc<Mutex<PacketRecvChanReceiver>>,
peer_manager: Arc<PeerManager>,
listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
@ -122,6 +123,8 @@ pub struct Instance {
#[cfg(feature = "socks5")]
socks5_server: Arc<Socks5Server>,
rpc_server: Option<StandAloneServer<TcpTunnelListener>>,
global_ctx: ArcGlobalCtx,
}
@ -170,6 +173,12 @@ impl Instance {
#[cfg(feature = "socks5")]
let socks5_server = Socks5Server::new(global_ctx.clone(), peer_manager.clone(), None);
let rpc_server = global_ctx.config.get_rpc_portal().and_then(|s| {
Some(StandAloneServer::new(TcpTunnelListener::new(
format!("tcp://{}", s).parse().unwrap(),
)))
});
Instance {
inst_name: global_ctx.inst_name.clone(),
id,
@ -177,7 +186,6 @@ impl Instance {
peer_packet_receiver: Arc::new(Mutex::new(peer_packet_receiver)),
nic_ctx: Arc::new(Mutex::new(None)),
tasks: JoinSet::new(),
peer_manager,
listener_manager,
conn_manager,
@ -193,6 +201,8 @@ impl Instance {
#[cfg(feature = "socks5")]
socks5_server,
rpc_server,
global_ctx,
}
}
@ -375,7 +385,7 @@ impl Instance {
self.check_dhcp_ip_conflict();
}
self.run_rpc_server()?;
self.run_rpc_server().await?;
// run after tun device created, so listener can bind to tun device, which may be required by win 10
self.ip_proxy = Some(IpProxy::new(
@ -441,11 +451,8 @@ impl Instance {
Ok(())
}
pub async fn wait(&mut self) {
while let Some(ret) = self.tasks.join_next().await {
tracing::info!("task finished: {:?}", ret);
ret.unwrap();
}
pub async fn wait(&self) {
self.peer_manager.wait().await;
}
pub fn id(&self) -> uuid::Uuid {
@ -456,24 +463,28 @@ impl Instance {
self.peer_manager.my_peer_id()
}
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc {
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc<Controller = BaseController> + Clone {
#[derive(Clone)]
struct VpnPortalRpcService {
peer_mgr: Weak<PeerManager>,
vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>,
}
#[tonic::async_trait]
#[async_trait::async_trait]
impl VpnPortalRpc for VpnPortalRpcService {
type Controller = BaseController;
async fn get_vpn_portal_info(
&self,
_request: tonic::Request<GetVpnPortalInfoRequest>,
) -> Result<tonic::Response<GetVpnPortalInfoResponse>, tonic::Status> {
_: BaseController,
_request: GetVpnPortalInfoRequest,
) -> Result<GetVpnPortalInfoResponse, rpc_types::error::Error> {
let Some(vpn_portal) = self.vpn_portal.upgrade() else {
return Err(tonic::Status::unavailable("vpn portal not available"));
return Err(anyhow::anyhow!("vpn portal not available").into());
};
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
return Err(tonic::Status::unavailable("peer manager not available"));
return Err(anyhow::anyhow!("peer manager not available").into());
};
let vpn_portal = vpn_portal.lock().await;
@ -485,7 +496,7 @@ impl Instance {
}),
};
Ok(tonic::Response::new(ret))
Ok(ret)
}
}
@ -495,46 +506,36 @@ impl Instance {
}
}
fn run_rpc_server(&mut self) -> Result<(), Error> {
let Some(addr) = self.global_ctx.config.get_rpc_portal() else {
async fn run_rpc_server(&mut self) -> Result<(), Error> {
let Some(_) = self.global_ctx.config.get_rpc_portal() else {
tracing::info!("rpc server not enabled, because rpc_portal is not set.");
return Ok(());
};
use crate::proto::cli::*;
let peer_mgr = self.peer_manager.clone();
let conn_manager = self.conn_manager.clone();
let net_ns = self.global_ctx.net_ns.clone();
let peer_center = self.peer_center.clone();
let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
let incoming = TcpIncoming::new(addr, true, None)
.map_err(|e| anyhow::anyhow!("create rpc server failed. addr: {}, err: {}", addr, e))?;
self.tasks.spawn(async move {
let _g = net_ns.guard();
Server::builder()
.add_service(
crate::rpc::peer_manage_rpc_server::PeerManageRpcServer::new(
PeerManagerRpcService::new(peer_mgr),
),
)
.add_service(
crate::rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new(
ConnectorManagerRpcService(conn_manager.clone()),
),
)
.add_service(
crate::rpc::peer_center_rpc_server::PeerCenterRpcServer::new(
peer_center.get_rpc_service(),
),
)
.add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new(
vpn_portal_rpc,
))
.serve_with_incoming(incoming)
.await
.with_context(|| format!("rpc server failed. addr: {}", addr))
.unwrap();
});
Ok(())
let s = self.rpc_server.as_mut().unwrap();
s.registry().register(
PeerManageRpcServer::new(PeerManagerRpcService::new(peer_mgr)),
"",
);
s.registry().register(
ConnectorManageRpcServer::new(ConnectorManagerRpcService(conn_manager)),
"",
);
s.registry()
.register(PeerCenterRpcServer::new(peer_center.get_rpc_service()), "");
s.registry()
.register(VpnPortalRpcServer::new(vpn_portal_rpc), "");
let _g = self.global_ctx.net_ns.guard();
Ok(s.serve().await.with_context(|| "rpc server start failed")?)
}
pub fn get_global_ctx(&self) -> ArcGlobalCtx {

View File

@ -159,8 +159,16 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
let tunnel_info = ret.info().unwrap();
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
tunnel_info.local_addr.clone(),
tunnel_info.remote_addr.clone(),
tunnel_info
.local_addr
.clone()
.unwrap_or_default()
.to_string(),
tunnel_info
.remote_addr
.clone()
.unwrap_or_default()
.to_string(),
));
tracing::info!(ret = ?ret, "conn accepted");
let peer_manager = peer_manager.clone();
@ -169,8 +177,8 @@ impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManage
let server_ret = peer_manager.handle_tunnel(ret).await;
if let Err(e) = &server_ret {
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
tunnel_info.local_addr,
tunnel_info.remote_addr,
tunnel_info.local_addr.unwrap_or_default().to_string(),
tunnel_info.remote_addr.unwrap_or_default().to_string(),
e.to_string(),
));
tracing::error!(error = ?e, "handle conn error");

View File

@ -11,9 +11,10 @@ use crate::{
},
instance::instance::Instance,
peers::rpc_service::PeerManagerRpcService,
rpc::{
cli::{PeerInfo, Route, StunInfo},
peer::GetIpListResponse,
proto::{
cli::{PeerInfo, Route},
common::StunInfo,
peer_rpc::GetIpListResponse,
},
utils::{list_peer_route_pair, PeerRoutePair},
};

View File

@ -6,11 +6,11 @@ mod gateway;
mod instance;
mod peer_center;
mod peers;
mod proto;
mod vpn_portal;
pub mod common;
pub mod launcher;
pub mod rpc;
pub mod tunnel;
pub mod utils;

View File

@ -1,7 +1,7 @@
use std::{
collections::BTreeSet,
sync::Arc,
time::{Duration, Instant, SystemTime},
time::{Duration, Instant},
};
use crossbeam::atomic::AtomicCell;
@ -18,14 +18,17 @@ use crate::{
route_trait::{RouteCostCalculator, RouteCostCalculatorInterface},
rpc_service::PeerManagerRpcService,
},
rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse},
proto::{
peer_rpc::{
GetGlobalPeerMapRequest, GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterRpc,
PeerCenterRpcClientFactory, PeerCenterRpcServer, PeerInfoForGlobalMap,
ReportPeersRequest, ReportPeersResponse,
},
rpc_types::{self, controller::BaseController},
},
};
use super::{
server::PeerCenterServer,
service::{GlobalPeerMap, PeerCenterService, PeerCenterServiceClient, PeerInfoForGlobalMap},
Digest, Error,
};
use super::{server::PeerCenterServer, Digest, Error};
struct PeerCenterBase {
peer_mgr: Arc<PeerManager>,
@ -44,11 +47,14 @@ struct PeridicJobCtx<T> {
impl PeerCenterBase {
pub async fn init(&self) -> Result<(), Error> {
self.peer_mgr.get_peer_rpc_mgr().run_service(
SERVICE_ID,
PeerCenterServer::new(self.peer_mgr.my_peer_id()).serve(),
);
self.peer_mgr
.get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
PeerCenterRpcServer::new(PeerCenterServer::new(self.peer_mgr.my_peer_id())),
&self.peer_mgr.get_global_ctx().get_network_name(),
);
Ok(())
}
@ -70,11 +76,17 @@ impl PeerCenterBase {
async fn init_periodic_job<
T: Send + Sync + 'static + Clone,
Fut: Future<Output = Result<u32, tarpc::client::RpcError>> + Send + 'static,
Fut: Future<Output = Result<u32, rpc_types::error::Error>> + Send + 'static,
>(
&self,
job_ctx: T,
job_fn: (impl Fn(PeerCenterServiceClient, Arc<PeridicJobCtx<T>>) -> Fut + Send + Sync + 'static),
job_fn: (impl Fn(
Box<dyn PeerCenterRpc<Controller = BaseController> + Send>,
Arc<PeridicJobCtx<T>>,
) -> Fut
+ Send
+ Sync
+ 'static),
) -> () {
let my_peer_id = self.peer_mgr.my_peer_id();
let peer_mgr = self.peer_mgr.clone();
@ -96,14 +108,14 @@ impl PeerCenterBase {
tracing::trace!(?center_peer, "run periodic job");
let rpc_mgr = peer_mgr.get_peer_rpc_mgr();
let _g = lock.lock().await;
let ret = rpc_mgr
.do_client_rpc_scoped(SERVICE_ID, center_peer, |c| async {
let client =
PeerCenterServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
job_fn(client, ctx.clone()).await
})
.await;
let stub = rpc_mgr
.rpc_client()
.scoped_client::<PeerCenterRpcClientFactory<BaseController>>(
my_peer_id,
center_peer,
peer_mgr.get_global_ctx().get_network_name(),
);
let ret = job_fn(stub, ctx.clone()).await;
drop(_g);
let Ok(sleep_time_ms) = ret else {
@ -130,25 +142,34 @@ impl PeerCenterBase {
}
}
#[derive(Clone)]
pub struct PeerCenterInstanceService {
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
global_peer_map_digest: Arc<AtomicCell<Digest>>,
}
#[tonic::async_trait]
impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstanceService {
#[async_trait::async_trait]
impl PeerCenterRpc for PeerCenterInstanceService {
type Controller = BaseController;
async fn get_global_peer_map(
&self,
_request: tonic::Request<GetGlobalPeerMapRequest>,
) -> Result<tonic::Response<GetGlobalPeerMapResponse>, tonic::Status> {
let global_peer_map = self.global_peer_map.read().unwrap().clone();
Ok(tonic::Response::new(GetGlobalPeerMapResponse {
global_peer_map: global_peer_map
.map
.into_iter()
.map(|(k, v)| (k, v))
.collect(),
}))
_: BaseController,
_: GetGlobalPeerMapRequest,
) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
let global_peer_map = self.global_peer_map.read().unwrap();
Ok(GetGlobalPeerMapResponse {
global_peer_map: global_peer_map.map.clone(),
digest: Some(self.global_peer_map_digest.load()),
})
}
async fn report_peers(
&self,
_: BaseController,
_req: ReportPeersRequest,
) -> Result<ReportPeersResponse, rpc_types::error::Error> {
Err(anyhow::anyhow!("not implemented").into())
}
}
@ -166,7 +187,7 @@ impl PeerCenterInstance {
PeerCenterInstance {
peer_mgr: peer_mgr.clone(),
client: Arc::new(PeerCenterBase::new(peer_mgr.clone())),
global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())),
global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::default())),
global_peer_map_digest: Arc::new(AtomicCell::new(Digest::default())),
global_peer_map_update_time: Arc::new(AtomicCell::new(Instant::now())),
}
@ -193,9 +214,6 @@ impl PeerCenterInstance {
self.client
.init_periodic_job(ctx, |client, ctx| async move {
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
if ctx
.job_ctx
.global_peer_map_update_time
@ -208,8 +226,13 @@ impl PeerCenterInstance {
}
let ret = client
.get_global_peer_map(rpc_ctx, ctx.job_ctx.global_peer_map_digest.load())
.await?;
.get_global_peer_map(
BaseController {},
GetGlobalPeerMapRequest {
digest: ctx.job_ctx.global_peer_map_digest.load(),
},
)
.await;
let Ok(resp) = ret else {
tracing::error!(
@ -219,9 +242,10 @@ impl PeerCenterInstance {
return Ok(1000);
};
let Some(resp) = resp else {
if resp == GetGlobalPeerMapResponse::default() {
// digest match, no need to update
return Ok(5000);
};
}
tracing::info!(
"get global info from center server: {:?}, digest: {:?}",
@ -229,8 +253,12 @@ impl PeerCenterInstance {
resp.digest
);
*ctx.job_ctx.global_peer_map.write().unwrap() = resp.global_peer_map;
ctx.job_ctx.global_peer_map_digest.store(resp.digest);
*ctx.job_ctx.global_peer_map.write().unwrap() = GlobalPeerMap {
map: resp.global_peer_map,
};
ctx.job_ctx
.global_peer_map_digest
.store(resp.digest.unwrap_or_default());
ctx.job_ctx
.global_peer_map_update_time
.store(Instant::now());
@ -274,12 +302,15 @@ impl PeerCenterInstance {
return Ok(5000);
}
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
let ret = client
.report_peers(rpc_ctx, my_node_id.clone(), peers)
.await?;
.report_peers(
BaseController {},
ReportPeersRequest {
my_peer_id: my_node_id,
peer_infos: Some(peers),
},
)
.await;
if ret.is_ok() {
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
@ -339,7 +370,7 @@ impl PeerCenterInstance {
Box::new(RouteCostCalculatorImpl {
global_peer_map: self.global_peer_map.clone(),
global_peer_map_clone: GlobalPeerMap::new(),
global_peer_map_clone: GlobalPeerMap::default(),
last_update_time: AtomicCell::new(
self.global_peer_map_update_time.load() - Duration::from_secs(1),
),

View File

@ -5,9 +5,13 @@
// peer center is not guaranteed to be stable and can be changed when peer enter or leave.
// it's used to reduce the cost to exchange infos between peers.
use std::collections::BTreeMap;
use crate::proto::cli::PeerInfo;
use crate::proto::peer_rpc::{DirectConnectedPeerInfo, PeerInfoForGlobalMap};
pub mod instance;
mod server;
mod service;
#[derive(thiserror::Error, Debug, serde::Deserialize, serde::Serialize)]
pub enum Error {
@ -18,3 +22,29 @@ pub enum Error {
}
pub type Digest = u64;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
fn from(peers: Vec<PeerInfo>) -> Self {
let mut peer_map = BTreeMap::new();
for peer in peers {
let Some(min_lat) = peer
.conns
.iter()
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
.min()
else {
continue;
};
let dp_info = DirectConnectedPeerInfo {
latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32),
};
// sort conn info so hash result is stable
peer_map.insert(peer.peer_id, dp_info);
}
PeerInfoForGlobalMap {
direct_peers: peer_map,
}
}
}

View File

@ -7,15 +7,22 @@ use std::{
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use once_cell::sync::Lazy;
use tokio::{task::JoinSet};
use tokio::task::JoinSet;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{
service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap},
Digest, Error,
use crate::{
common::PeerId,
proto::{
peer_rpc::{
DirectConnectedPeerInfo, GetGlobalPeerMapRequest, GetGlobalPeerMapResponse,
GlobalPeerMap, PeerCenterRpc, PeerInfoForGlobalMap, ReportPeersRequest,
ReportPeersResponse,
},
rpc_types::{self, controller::BaseController},
},
};
use super::Digest;
#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
pub(crate) struct SrcDstPeerPair {
src: PeerId,
@ -95,15 +102,19 @@ impl PeerCenterServer {
}
}
#[tarpc::server]
impl PeerCenterService for PeerCenterServer {
#[async_trait::async_trait]
impl PeerCenterRpc for PeerCenterServer {
type Controller = BaseController;
#[tracing::instrument()]
async fn report_peers(
self,
_: tarpc::context::Context,
my_peer_id: PeerId,
peers: PeerInfoForGlobalMap,
) -> Result<(), Error> {
&self,
_: BaseController,
req: ReportPeersRequest,
) -> Result<ReportPeersResponse, rpc_types::error::Error> {
let my_peer_id = req.my_peer_id;
let peers = req.peer_infos.unwrap_or_default();
tracing::debug!("receive report_peers");
let data = get_global_data(self.my_node_id);
@ -125,20 +136,23 @@ impl PeerCenterService for PeerCenterServer {
data.digest
.store(PeerCenterServer::calc_global_digest(self.my_node_id));
Ok(())
Ok(ReportPeersResponse::default())
}
#[tracing::instrument()]
async fn get_global_peer_map(
self,
_: tarpc::context::Context,
digest: Digest,
) -> Result<Option<GetGlobalPeerMapResponse>, Error> {
&self,
_: BaseController,
req: GetGlobalPeerMapRequest,
) -> Result<GetGlobalPeerMapResponse, rpc_types::error::Error> {
let digest = req.digest;
let data = get_global_data(self.my_node_id);
if digest == data.digest.load() && digest != 0 {
return Ok(None);
return Ok(GetGlobalPeerMapResponse::default());
}
let mut global_peer_map = GlobalPeerMap::new();
let mut global_peer_map = GlobalPeerMap::default();
for item in data.global_peer_map.iter() {
let (pair, entry) = item.pair();
global_peer_map
@ -151,9 +165,9 @@ impl PeerCenterService for PeerCenterServer {
.insert(pair.dst, entry.info.clone());
}
Ok(Some(GetGlobalPeerMapResponse {
global_peer_map,
digest: data.digest.load(),
}))
Ok(GetGlobalPeerMapResponse {
global_peer_map: global_peer_map.map,
digest: Some(data.digest.load()),
})
}
}

View File

@ -1,64 +0,0 @@
use std::collections::BTreeMap;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{Digest, Error};
use crate::rpc::PeerInfo;
pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
fn from(peers: Vec<PeerInfo>) -> Self {
let mut peer_map = BTreeMap::new();
for peer in peers {
let Some(min_lat) = peer
.conns
.iter()
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
.min()
else {
continue;
};
let dp_info = DirectConnectedPeerInfo {
latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32),
};
// sort conn info so hash result is stable
peer_map.insert(peer.peer_id, dp_info);
}
PeerInfoForGlobalMap {
direct_peers: peer_map,
}
}
}
// a global peer topology map, peers can use it to find optimal path to other peers
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GlobalPeerMap {
pub map: BTreeMap<PeerId, PeerInfoForGlobalMap>,
}
impl GlobalPeerMap {
pub fn new() -> Self {
GlobalPeerMap {
map: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct GetGlobalPeerMapResponse {
pub global_peer_map: GlobalPeerMap,
pub digest: Digest,
}
#[tarpc::service]
pub trait PeerCenterService {
// report center server which peer is directly connected to me
// digest is a hash of current peer map, if digest not match, we need to transfer the whole map
async fn report_peers(my_peer_id: PeerId, peers: PeerInfoForGlobalMap) -> Result<(), Error>;
async fn get_global_peer_map(digest: Digest)
-> Result<Option<GetGlobalPeerMapResponse>, Error>;
}

View File

@ -1,27 +1,11 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use tokio::{sync::Mutex, task::JoinSet};
use std::sync::{Arc, Mutex};
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
},
common::{error::Error, global_ctx::ArcGlobalCtx, scoped_task::ScopedTask, PeerId},
tunnel::packet_def::ZCPacket,
};
use super::{
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
PacketRecvChan,
};
use super::{peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager, PacketRecvChan};
pub struct ForeignNetworkClient {
global_ctx: ArcGlobalCtx,
@ -29,9 +13,7 @@ pub struct ForeignNetworkClient {
my_peer_id: PeerId,
peer_map: Arc<PeerMap>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
tasks: Mutex<JoinSet<()>>,
task: Mutex<Option<ScopedTask<()>>>,
}
impl ForeignNetworkClient {
@ -46,17 +28,13 @@ impl ForeignNetworkClient {
global_ctx.clone(),
my_peer_id,
));
let next_hop = Arc::new(DashMap::new());
Self {
global_ctx,
peer_rpc,
my_peer_id,
peer_map,
next_hop,
tasks: Mutex::new(JoinSet::new()),
task: Mutex::new(None),
}
}
@ -65,91 +43,19 @@ impl ForeignNetworkClient {
self.peer_map.add_new_peer_conn(peer_conn).await
}
async fn collect_next_hop_in_foreign_network_task(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
) {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
peer_map.clean_peer_without_conn().await;
let new_next_hop = Self::collect_next_hop_in_foreign_network(
network_identity.clone(),
peer_map.clone(),
peer_rpc.clone(),
)
.await;
next_hop.clear();
for (k, v) in new_next_hop.into_iter() {
next_hop.insert(k, v);
}
}
}
async fn collect_next_hop_in_foreign_network(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
) -> DashMap<PeerId, PeerId> {
let peers = peer_map.list_peers().await;
let mut tasks = JoinSet::new();
if !peers.is_empty() {
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
}
for peer in peers {
let peer_rpc = peer_rpc.clone();
let network_identity = network_identity.clone();
tasks.spawn(async move {
let Ok(Some(peers_in_foreign)) = peer_rpc
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
let c =
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
ret
})
.await
else {
return (peer, vec![]);
};
(peer, peers_in_foreign)
});
}
let new_next_hop = DashMap::new();
while let Some(join_ret) = tasks.join_next().await {
let Ok((gateway, peer_ids)) = join_ret else {
tracing::error!(?join_ret, "collect next hop in foreign network failed");
continue;
};
for ret in peer_ids {
new_next_hop.insert(ret, gateway);
}
}
new_next_hop
}
pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
self.get_next_hop(peer_id).is_some()
}
pub fn is_peer_public_node(&self, peer_id: &PeerId) -> bool {
self.peer_map.has_peer(*peer_id)
pub async fn list_public_peers(&self) -> Vec<PeerId> {
self.peer_map.list_peers().await
}
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone());
}
self.next_hop.get(&peer_id).map(|v| v.clone())
None
}
pub async fn send_msg(&self, msg: ZCPacket, peer_id: PeerId) -> Result<(), Error> {
@ -162,40 +68,32 @@ impl ForeignNetworkClient {
?next_hop,
"foreign network client send msg failed"
);
} else {
tracing::info!(
?peer_id,
?next_hop,
"foreign network client send msg success"
);
}
return ret;
}
Err(Error::RouteError(Some("no next hop".to_string())))
}
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
let mut peers = vec![];
for item in self.next_hop.iter() {
if item.key() != &self.my_peer_id {
peers.push(item.key().clone());
}
}
peers
}
pub async fn run(&self) {
self.tasks
.lock()
.await
.spawn(Self::collect_next_hop_in_foreign_network_task(
self.global_ctx.get_network_identity(),
self.peer_map.clone(),
self.peer_rpc.clone(),
self.next_hop.clone(),
));
}
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> {
let next_hop = DashMap::new();
for item in self.next_hop.iter() {
next_hop.insert(item.key().clone(), item.value().clone());
}
next_hop
let peer_map = Arc::downgrade(&self.peer_map);
*self.task.lock().unwrap() = Some(
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
let Some(peer_map) = peer_map.upgrade() else {
break;
};
peer_map.clean_peer_without_conn().await;
}
})
.into(),
);
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {

View File

@ -5,12 +5,12 @@ only forward packets of peers that directly connected to this node.
in future, with the help wo peer center we can forward packets of peers that
connected to any node in the local network.
*/
use std::sync::Arc;
use std::sync::{Arc, Weak};
use dashmap::DashMap;
use tokio::{
sync::{
mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender},
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex,
},
task::JoinSet,
@ -18,26 +18,35 @@ use tokio::{
use crate::{
common::{
config::{ConfigLoader, TomlConfigLoader},
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent, NetworkIdentity},
stun::MockStunInfoCollector,
PeerId,
},
rpc::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo},
peers::route_trait::{Route, RouteInterface},
proto::{
cli::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo},
common::NatType,
},
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::{
peer_conn::PeerConn,
peer_map::PeerMap,
peer_ospf_route::PeerRoute,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
route_trait::NextHopPolicy,
route_trait::{ArcRoute, NextHopPolicy},
PacketRecvChan, PacketRecvChanReceiver,
};
struct ForeignNetworkEntry {
global_ctx: ArcGlobalCtx,
network: NetworkIdentity,
peer_map: Arc<PeerMap>,
relay_data: bool,
route: ArcRoute,
}
impl ForeignNetworkEntry {
@ -47,19 +56,70 @@ impl ForeignNetworkEntry {
global_ctx: ArcGlobalCtx,
my_peer_id: PeerId,
relay_data: bool,
peer_rpc: Arc<PeerRpcManager>,
) -> Self {
let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id));
let config = TomlConfigLoader::default();
config.set_network_identity(network.clone());
config.set_hostname(Some(format!("PublicServer_{}", global_ctx.get_hostname())));
let foreign_global_ctx = Arc::new(GlobalCtx::new(config));
foreign_global_ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector {
udp_nat_type: NatType::Unknown,
}));
let peer_map = Arc::new(PeerMap::new(
packet_sender,
foreign_global_ctx.clone(),
my_peer_id,
));
let route = PeerRoute::new(my_peer_id, foreign_global_ctx.clone(), peer_rpc);
Self {
global_ctx: foreign_global_ctx,
network,
peer_map,
relay_data,
route: Arc::new(Box::new(route)),
}
}
async fn prepare(&self, my_peer_id: PeerId) {
struct Interface {
my_peer_id: PeerId,
peer_map: Weak<PeerMap>,
}
#[async_trait::async_trait]
impl RouteInterface for Interface {
async fn list_peers(&self) -> Vec<PeerId> {
let Some(peer_map) = self.peer_map.upgrade() else {
return vec![];
};
peer_map.list_peers_with_conn().await
}
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
}
self.route
.open(Box::new(Interface {
my_peer_id,
peer_map: Arc::downgrade(&self.peer_map),
}))
.await
.unwrap();
self.peer_map.add_route(self.route.clone()).await;
}
}
struct ForeignNetworkManagerData {
network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
peer_network_map: DashMap<PeerId, String>,
lock: std::sync::Mutex<()>,
}
impl ForeignNetworkManagerData {
@ -88,18 +148,27 @@ impl ForeignNetworkManagerData {
self.network_peer_maps.get(network_name).map(|v| v.clone())
}
fn remove_peer(&self, peer_id: PeerId) {
fn remove_peer(&self, peer_id: PeerId, network_name: &String) {
let _l = self.lock.lock().unwrap();
self.peer_network_map.remove(&peer_id);
self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty());
self.network_peer_maps
.remove_if(network_name, |_, v| v.peer_map.is_empty());
}
fn clear_no_conn_peer(&self) {
for item in self.network_peer_maps.iter() {
let peer_map = item.value().peer_map.clone();
tokio::spawn(async move {
peer_map.clean_peer_without_conn().await;
});
}
async fn clear_no_conn_peer(&self, network_name: &String) {
let peer_map = self
.network_peer_maps
.get(network_name)
.unwrap()
.peer_map
.clone();
peer_map.clean_peer_without_conn().await;
}
fn remove_network(&self, network_name: &String) {
let _l = self.lock.lock().unwrap();
self.peer_network_map.retain(|_, v| v != network_name);
self.network_peer_maps.remove(network_name);
}
}
@ -117,11 +186,16 @@ impl PeerRpcManagerTransport for RpcTransport {
}
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
tracing::debug!(
"foreign network manager send rpc to peer: {:?}",
dst_peer_id
);
self.data.send_msg(msg, dst_peer_id).await
}
async fn recv(&self) -> Result<ZCPacket, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
tracing::info!("recv rpc packet in foreign network manager rpc transport");
Ok(o)
} else {
Err(Error::Unknown)
@ -131,23 +205,6 @@ impl PeerRpcManagerTransport for RpcTransport {
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
#[tarpc::service]
pub trait ForeignNetworkService {
async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
}
#[tarpc::server]
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
async fn list_network_peers(
self,
_: tarpc::context::Context,
network_identy: NetworkIdentity,
) -> Option<Vec<PeerId>> {
let entry = self.network_peer_maps.get(&network_identy.network_name)?;
Some(entry.peer_map.list_peers().await)
}
}
pub struct ForeignNetworkManager {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
@ -175,6 +232,7 @@ impl ForeignNetworkManager {
let data = Arc::new(ForeignNetworkManagerData {
network_peer_maps: DashMap::new(),
peer_network_map: DashMap::new(),
lock: std::sync::Mutex::new(()),
});
// handle rpc from foreign networks
@ -225,25 +283,39 @@ impl ForeignNetworkManager {
return ret;
}
let entry = self
.data
.network_peer_maps
.entry(peer_conn.get_network_identity().network_name.clone())
.or_insert_with(|| {
Arc::new(ForeignNetworkEntry::new(
peer_conn.get_network_identity(),
self.packet_sender.clone(),
self.global_ctx.clone(),
self.my_peer_id,
!ret.is_err(),
))
})
.clone();
let mut new_added = false;
self.data.peer_network_map.insert(
peer_conn.get_peer_id(),
peer_conn.get_network_identity().network_name.clone(),
);
let entry = {
let _l = self.data.lock.lock().unwrap();
let entry = self
.data
.network_peer_maps
.entry(peer_conn.get_network_identity().network_name.clone())
.or_insert_with(|| {
new_added = true;
Arc::new(ForeignNetworkEntry::new(
peer_conn.get_network_identity(),
self.packet_sender.clone(),
self.global_ctx.clone(),
self.my_peer_id,
!ret.is_err(),
self.rpc_mgr.clone(),
))
})
.clone();
self.data.peer_network_map.insert(
peer_conn.get_peer_id(),
peer_conn.get_network_identity().network_name.clone(),
);
entry
};
if new_added {
entry.prepare(self.my_peer_id).await;
self.start_event_handler(&entry).await;
}
if entry.network != peer_conn.get_network_identity() {
return Err(anyhow::anyhow!(
@ -257,28 +329,26 @@ impl ForeignNetworkManager {
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
}
async fn start_global_event_handler(&self) {
async fn start_event_handler(&self, entry: &ForeignNetworkEntry) {
let data = self.data.clone();
let mut s = self.global_ctx.subscribe();
let (ev_tx, mut ev_rx) = unbounded_channel();
let network_name = entry.network.network_name.clone();
let mut s = entry.global_ctx.subscribe();
self.tasks.lock().await.spawn(async move {
while let Ok(e) = s.recv().await {
ev_tx.send(e).unwrap();
}
panic!("global event handler at foreign network manager exit");
});
self.tasks.lock().await.spawn(async move {
while let Some(e) = ev_rx.recv().await {
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
tracing::info!(?e, "remove peer from foreign network manager");
data.remove_peer(*peer_id);
data.remove_peer(*peer_id, &network_name);
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
tracing::info!(?e, "clear no conn peer from foreign network manager");
data.clear_no_conn_peer();
data.clear_no_conn_peer(&network_name).await;
}
}
// if lagged or recv done just remove the network
tracing::error!("global event handler at foreign network manager exit");
data.remove_network(&network_name);
});
self.tasks.lock().await.spawn(async move {});
}
async fn start_packet_recv(&self) {
@ -294,10 +364,14 @@ impl ForeignNetworkManager {
tracing::warn!("invalid packet, skip");
continue;
};
tracing::info!(?hdr, "recv packet in foreign network manager");
let from_peer_id = hdr.from_peer_id.get();
let to_peer_id = hdr.to_peer_id.get();
if to_peer_id == my_node_id {
if hdr.packet_type == PacketType::TaRpc as u8 {
if hdr.packet_type == PacketType::TaRpc as u8
|| hdr.packet_type == PacketType::RpcReq as u8
|| hdr.packet_type == PacketType::RpcResp as u8
{
rpc_sender.send(packet_bytes).unwrap();
continue;
}
@ -335,16 +409,9 @@ impl ForeignNetworkManager {
});
}
async fn register_peer_rpc_service(&self) {
self.rpc_mgr.run();
self.rpc_mgr
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
}
pub async fn run(&self) {
self.start_global_event_handler().await;
self.start_packet_recv().await;
self.register_peer_rpc_service().await;
self.rpc_mgr.run();
}
pub async fn list_foreign_networks(&self) -> ListForeignNetworkResponse {
@ -380,8 +447,17 @@ impl ForeignNetworkManager {
}
}
impl Drop for ForeignNetworkManager {
fn drop(&mut self) {
self.data.peer_network_map.clear();
self.data.network_peer_maps.clear();
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use crate::{
common::global_ctx::tests::get_mock_global_ctx_with_network,
connector::udp_hole_punch::tests::{
@ -391,7 +467,8 @@ mod tests {
peer_manager::{PeerManager, RouteAlgoType},
tests::{connect_peer_manager, wait_route_appear},
},
rpc::NatType,
proto::common::NatType,
tunnel::common::tests::wait_for_condition,
};
use super::*;
@ -413,7 +490,7 @@ mod tests {
#[tokio::test]
async fn foreign_network_basic() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
@ -428,8 +505,10 @@ mod tests {
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
assert_eq!(1, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len());
assert_eq!(2, pma_net1.list_routes().await.len());
assert_eq!(2, pmb_net1.list_routes().await.len());
println!("{:?}", pmb_net1.list_routes().await);
let rpc_resp = pm_center
.get_foreign_network_manager()
@ -440,7 +519,7 @@ mod tests {
}
async fn foreign_network_whitelist_helper(name: String) {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let mut flag = pm_center.get_global_ctx().get_flags();
flag.foreign_network_whitelist = vec!["net1".to_string(), "net2*".to_string()].join(" ");
@ -466,7 +545,7 @@ mod tests {
#[tokio::test]
async fn only_relay_peer_rpc() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let mut flag = pm_center.get_global_ctx().get_flags();
flag.foreign_network_whitelist = "".to_string();
flag.relay_all_peer_rpc = true;
@ -485,8 +564,8 @@ mod tests {
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
assert_eq!(1, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len());
assert_eq!(2, pma_net1.list_routes().await.len());
assert_eq!(2, pmb_net1.list_routes().await.len());
}
#[tokio::test]
@ -497,9 +576,8 @@ mod tests {
#[tokio::test]
async fn test_foreign_network_manager() {
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center2 =
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let pm_center2 = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
tracing::debug!(
@ -519,17 +597,9 @@ mod tests {
pmb_net1.my_peer_id()
);
let now = std::time::Instant::now();
let mut succ = false;
while now.elapsed().as_secs() < 10 {
let table = pma_net1.get_foreign_network_client().get_next_hop_table();
if table.len() >= 1 {
succ = true;
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
assert!(succ);
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
assert_eq!(
vec![pm_center.my_peer_id()],
@ -547,11 +617,9 @@ mod tests {
.list_peers()
.await
);
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
.await
.unwrap();
assert_eq!(1, pma_net1.list_routes().await.len());
assert_eq!(1, pmb_net1.list_routes().await.len());
assert_eq!(2, pma_net1.list_routes().await.len());
assert_eq!(2, pmb_net1.list_routes().await.len());
let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
@ -561,7 +629,7 @@ mod tests {
wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
.await
.unwrap();
assert_eq!(2, pmc_net1.list_routes().await.len());
assert_eq!(3, pmc_net1.list_routes().await.len());
tracing::debug!("pmc_net1: {:?}", pmc_net1.my_peer_id());
@ -577,8 +645,8 @@ mod tests {
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
.await
.unwrap();
assert_eq!(1, pma_net2.list_routes().await.len());
assert_eq!(1, pmb_net2.list_routes().await.len());
assert_eq!(2, pma_net2.list_routes().await.len());
assert_eq!(2, pmb_net2.list_routes().await.len());
assert_eq!(
5,
@ -635,4 +703,27 @@ mod tests {
.len()
);
}
#[tokio::test]
async fn test_disconnect_foreign_network() {
let pm_center = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
tracing::debug!("pma_net1: {:?}", pma_net1.my_peer_id(),);
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
wait_for_condition(
|| async { pma_net1.list_routes().await.len() == 1 },
Duration::from_secs(5),
)
.await;
drop(pm_center);
wait_for_condition(
|| async { pma_net1.list_routes().await.len() == 0 },
Duration::from_secs(5),
)
.await;
}
}

View File

@ -5,7 +5,6 @@ pub mod peer_conn_ping;
pub mod peer_manager;
pub mod peer_map;
pub mod peer_ospf_route;
pub mod peer_rip_route;
pub mod peer_rpc;
pub mod route_trait;
pub mod rpc_service;

View File

@ -11,7 +11,7 @@ use super::{
peer_conn::{PeerConn, PeerConnId},
PacketRecvChan,
};
use crate::rpc::PeerConnInfo;
use crate::proto::cli::PeerConnInfo;
use crate::{
common::{
error::Error,

View File

@ -29,8 +29,18 @@ use crate::{
global_ctx::ArcGlobalCtx,
PeerId,
},
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
tunnel::{filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter}, mpsc::{MpscTunnel, MpscTunnelSender}, packet_def::{PacketType, ZCPacket}, stats::{Throughput, WindowLatency}, Tunnel, TunnelError, ZCPacketStream},
proto::{
cli::{PeerConnInfo, PeerConnStats},
common::TunnelInfo,
peer_rpc::HandshakeRequest,
},
tunnel::{
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
mpsc::{MpscTunnel, MpscTunnelSender},
packet_def::{PacketType, ZCPacket},
stats::{Throughput, WindowLatency},
Tunnel, TunnelError, ZCPacketStream,
},
};
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};

View File

@ -17,7 +17,6 @@ use tokio::{
task::JoinSet,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::Bytes;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
@ -27,6 +26,7 @@ use crate::{
route_trait::{NextHopPolicy, RouteInterface},
PeerPacketFilter,
},
proto::cli,
tunnel::{
self,
packet_def::{PacketType, ZCPacket},
@ -41,7 +41,6 @@ use super::{
peer_conn::PeerConnId,
peer_map::PeerMap,
peer_ospf_route::PeerRoute,
peer_rip_route::BasicRoute,
peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route},
BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver,
@ -75,7 +74,15 @@ impl PeerRpcManagerTransport for RpcTransport {
.ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
if let Some(gateway_id) = peers
if foreign_peers.has_next_hop(dst_peer_id) {
// do not encrypt for data sending to public server
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
foreign_peers.send_msg(msg, dst_peer_id).await
} else if let Some(gateway_id) = peers
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
.await
{
@ -88,20 +95,11 @@ impl PeerRpcManagerTransport for RpcTransport {
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
peers.send_msg_directly(msg, gateway_id).await
} else if foreign_peers.has_next_hop(dst_peer_id) {
if !foreign_peers.is_peer_public_node(&dst_peer_id) {
// do not encrypt for msg sending to public node
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
if peers.has_peer(gateway_id) {
peers.send_msg_directly(msg, gateway_id).await
} else {
foreign_peers.send_msg(msg, gateway_id).await
}
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
foreign_peers.send_msg(msg, dst_peer_id).await
} else {
Err(Error::RouteError(Some(format!(
"peermgr RpcTransport no route for dst_peer_id: {}",
@ -120,13 +118,11 @@ impl PeerRpcManagerTransport for RpcTransport {
}
pub enum RouteAlgoType {
Rip,
Ospf,
None,
}
enum RouteAlgoInst {
Rip(Arc<BasicRoute>),
Ospf(Arc<PeerRoute>),
None,
}
@ -217,9 +213,6 @@ impl PeerManager {
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
let route_algo_inst = match route_algo {
RouteAlgoType::Rip => {
RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
}
RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
my_peer_id,
global_ctx.clone(),
@ -438,7 +431,10 @@ impl PeerManager {
impl PeerPacketFilter for PeerRpcPacketProcessor {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let hdr = packet.peer_manager_header().unwrap();
if hdr.packet_type == PacketType::TaRpc as u8 {
if hdr.packet_type == PacketType::TaRpc as u8
|| hdr.packet_type == PacketType::RpcReq as u8
|| hdr.packet_type == PacketType::RpcResp as u8
{
self.peer_rpc_tspt_sender.send(packet).unwrap();
None
} else {
@ -477,33 +473,11 @@ impl PeerManager {
return vec![];
};
let mut peers = foreign_client.list_foreign_peers();
let mut peers = foreign_client.list_public_peers().await;
peers.extend(peer_map.list_peers_with_conn().await);
peers
}
async fn send_route_packet(
&self,
msg: Bytes,
_route_id: u8,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let foreign_client = self
.foreign_network_client
.upgrade()
.ok_or(Error::Unknown)?;
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
let mut zc_packet = ZCPacket::new_with_payload(&msg);
zc_packet.fill_peer_manager_hdr(
self.my_peer_id,
dst_peer_id,
PacketType::Route as u8,
);
if foreign_client.has_next_hop(dst_peer_id) {
foreign_client.send_msg(zc_packet, dst_peer_id).await
} else {
peer_map.send_msg_directly(zc_packet, dst_peer_id).await
}
}
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
@ -525,13 +499,12 @@ impl PeerManager {
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => Box::new(route.clone()),
RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
RouteAlgoInst::None => panic!("no route"),
}
}
pub async fn list_routes(&self) -> Vec<crate::rpc::Route> {
pub async fn list_routes(&self) -> Vec<cli::Route> {
self.get_route().list_routes().await
}
@ -649,13 +622,23 @@ impl PeerManager {
.get_gateway_peer_id(*peer_id, next_hop_policy.clone())
.await
{
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
errs.push(e);
}
} else if self.foreign_network_client.has_next_hop(*peer_id) {
if let Err(e) = self.foreign_network_client.send_msg(msg, *peer_id).await {
errs.push(e);
if self.peers.has_peer(gateway) {
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {
errs.push(e);
}
} else if self.foreign_network_client.has_next_hop(gateway) {
if let Err(e) = self.foreign_network_client.send_msg(msg, gateway).await {
errs.push(e);
}
} else {
tracing::warn!(
?gateway,
?peer_id,
"cannot send msg to peer through gateway"
);
}
} else {
tracing::debug!(?peer_id, "no gateway for peer");
}
}
@ -693,7 +676,6 @@ impl PeerManager {
pub async fn run(&self) -> Result<(), Error> {
match &self.route_algo_inst {
RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
RouteAlgoInst::None => {}
};
@ -732,13 +714,6 @@ impl PeerManager {
self.nic_channel.clone()
}
pub fn get_basic_route(&self) -> Arc<BasicRoute> {
match &self.route_algo_inst {
RouteAlgoInst::Rip(route) => route.clone(),
_ => panic!("not rip route"),
}
}
pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
self.foreign_network_manager.clone()
}
@ -747,8 +722,8 @@ impl PeerManager {
self.foreign_network_client.clone()
}
pub fn get_my_info(&self) -> crate::rpc::NodeInfo {
crate::rpc::NodeInfo {
pub fn get_my_info(&self) -> cli::NodeInfo {
cli::NodeInfo {
peer_id: self.my_peer_id,
ipv4_addr: self
.global_ctx
@ -774,6 +749,12 @@ impl PeerManager {
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
pub async fn wait(&self) {
while !self.tasks.lock().await.is_empty() {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
}
#[cfg(test)]
@ -789,12 +770,11 @@ mod tests {
instance::listeners::get_listener_by_url,
peers::{
peer_manager::RouteAlgoType,
peer_rpc::tests::{MockService, TestRpcService, TestRpcServiceClient},
peer_rpc::tests::register_service,
tests::{connect_peer_manager, wait_route_appear},
},
rpc::NatType,
tunnel::common::tests::wait_for_condition,
tunnel::{TunnelConnector, TunnelListener},
proto::common::NatType,
tunnel::{common::tests::wait_for_condition, TunnelConnector, TunnelListener},
};
use super::PeerManager;
@ -857,25 +837,18 @@ mod tests {
#[values("tcp", "udp", "wg", "quic")] proto1: &str,
#[values("tcp", "udp", "wg", "quic")] proto2: &str,
) {
use crate::proto::{
rpc_impl::RpcController,
tests::{GreetingClientFactory, SayHelloRequest},
};
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_a.get_peer_rpc_mgr().run_service(
100,
MockService {
prefix: "hello a".to_owned(),
}
.serve(),
);
register_service(&peer_mgr_a.peer_rpc_mgr, "", 0, "hello a");
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
peer_mgr_c.get_peer_rpc_mgr().run_service(
100,
MockService {
prefix: "hello c".to_owned(),
}
.serve(),
);
register_service(&peer_mgr_c.peer_rpc_mgr, "", 0, "hello c");
let mut listener1 = get_listener_by_url(
&format!("{}://0.0.0.0:31013", proto1).parse().unwrap(),
@ -913,16 +886,26 @@ mod tests {
.await
.unwrap();
let ret = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(100, peer_mgr_c.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
let stub = peer_mgr_a
.peer_rpc_mgr
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_a.my_peer_id,
peer_mgr_c.my_peer_id,
"".to_string(),
);
let ret = stub
.say_hello(
RpcController {},
SayHelloRequest {
name: "abc".to_string(),
},
)
.await
.unwrap();
assert_eq!(ret, "hello c abc");
assert_eq!(ret.greeting, "hello c abc!");
}
#[tokio::test]

View File

@ -10,7 +10,7 @@ use crate::{
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId,
},
rpc::PeerConnInfo,
proto::cli::PeerConnInfo,
tunnel::packet_def::ZCPacket,
tunnel::TunnelError,
};
@ -66,7 +66,7 @@ impl PeerMap {
}
pub fn has_peer(&self, peer_id: PeerId) -> bool {
self.peer_map.contains_key(&peer_id)
peer_id == self.my_peer_id || self.peer_map.contains_key(&peer_id)
}
pub async fn send_msg_directly(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
@ -113,10 +113,8 @@ impl PeerMap {
.get_next_hop_with_policy(dst_peer_id, policy.clone())
.await
{
// for foreign network, gateway_peer_id may not connect to me
if self.has_peer(gateway_peer_id) {
return Some(gateway_peer_id);
}
// NOTIC: for foreign network, gateway_peer_id may not connect to me
return Some(gateway_peer_id);
}
}

View File

@ -25,7 +25,18 @@ use tokio::{
use crate::{
common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::route_trait::{Route, RouteInterfaceBox},
rpc::{NatType, StunInfo},
proto::common::{NatType, StunInfo},
proto::{
peer_rpc::{
OspfRouteRpc, OspfRouteRpcClientFactory, OspfRouteRpcServer, PeerIdVersion,
RoutePeerInfo, RoutePeerInfos, SyncRouteInfoError, SyncRouteInfoRequest,
SyncRouteInfoResponse,
},
rpc_types::{
self,
controller::{BaseController, Controller},
},
},
};
use super::{
@ -76,31 +87,17 @@ impl From<Version> for AtomicVersion {
}
}
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct RoutePeerInfo {
// means next hop in route table.
peer_id: PeerId,
inst_id: uuid::Uuid,
cost: u8,
ipv4_addr: Option<Ipv4Addr>,
proxy_cidrs: Vec<String>,
hostname: Option<String>,
udp_stun_info: i8,
last_update: SystemTime,
version: Version,
}
impl RoutePeerInfo {
pub fn new() -> Self {
Self {
peer_id: 0,
inst_id: uuid::Uuid::nil(),
inst_id: Some(uuid::Uuid::nil().into()),
cost: 0,
ipv4_addr: None,
proxy_cidrs: Vec::new(),
hostname: None,
udp_stun_info: 0,
last_update: SystemTime::now(),
last_update: Some(SystemTime::now().into()),
version: 0,
}
}
@ -108,9 +105,9 @@ impl RoutePeerInfo {
pub fn update_self(&self, my_peer_id: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
let mut new = Self {
peer_id: my_peer_id,
inst_id: global_ctx.get_id(),
inst_id: Some(global_ctx.get_id().into()),
cost: 0,
ipv4_addr: global_ctx.get_ipv4(),
ipv4_addr: global_ctx.get_ipv4().map(|x| x.into()),
proxy_cidrs: global_ctx
.get_proxy_cidrs()
.iter()
@ -121,20 +118,22 @@ impl RoutePeerInfo {
udp_stun_info: global_ctx
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type as i8,
.udp_nat_type,
// following fields do not participate in comparison.
last_update: self.last_update,
version: self.version,
};
let need_update_periodically = if let Ok(d) = new.last_update.elapsed() {
let need_update_periodically = if let Ok(Ok(d)) =
SystemTime::try_from(new.last_update.unwrap()).map(|x| x.elapsed())
{
d > UPDATE_PEER_INFO_PERIOD
} else {
true
};
if new != *self || need_update_periodically {
new.last_update = SystemTime::now();
new.last_update = Some(SystemTime::now().into());
new.version += 1;
}
@ -142,9 +141,9 @@ impl RoutePeerInfo {
}
}
impl Into<crate::rpc::Route> for RoutePeerInfo {
fn into(self) -> crate::rpc::Route {
crate::rpc::Route {
impl Into<crate::proto::cli::Route> for RoutePeerInfo {
fn into(self) -> crate::proto::cli::Route {
crate::proto::cli::Route {
peer_id: self.peer_id,
ipv4_addr: if let Some(ipv4_addr) = self.ipv4_addr {
ipv4_addr.to_string()
@ -162,7 +161,7 @@ impl Into<crate::rpc::Route> for RoutePeerInfo {
}
Some(stun_info)
},
inst_id: self.inst_id.to_string(),
inst_id: self.inst_id.map(|x| x.to_string()).unwrap_or_default(),
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
@ -174,6 +173,35 @@ struct RouteConnBitmap {
bitmap: Vec<u8>,
}
// Convert the internal conn-bitmap into its protobuf wire form.
// Implemented as `From` (not `Into`) per Rust convention — the std blanket
// impl still gives callers `.into()` for free, so this is backward compatible.
impl From<RouteConnBitmap> for crate::proto::peer_rpc::RouteConnBitmap {
    fn from(v: RouteConnBitmap) -> Self {
        crate::proto::peer_rpc::RouteConnBitmap {
            // each (peer_id, version) pair becomes a PeerIdVersion message
            peer_ids: v
                .peer_ids
                .into_iter()
                .map(|(peer_id, version)| PeerIdVersion { peer_id, version })
                .collect(),
            bitmap: v.bitmap,
        }
    }
}
impl From<crate::proto::peer_rpc::RouteConnBitmap> for RouteConnBitmap {
fn from(v: crate::proto::peer_rpc::RouteConnBitmap) -> Self {
RouteConnBitmap {
peer_ids: v
.peer_ids
.into_iter()
.map(|x| (x.peer_id, x.version))
.collect(),
bitmap: v.bitmap,
}
}
}
impl RouteConnBitmap {
fn new() -> Self {
RouteConnBitmap {
@ -200,28 +228,7 @@ impl RouteConnBitmap {
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
enum Error {
DuplicatePeerId,
Stopped,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct SyncRouteInfoResponse {
is_initiator: bool,
session_id: SessionId,
}
#[tarpc::service]
trait RouteService {
async fn sync_route_info(
my_peer_id: PeerId,
my_session_id: SessionId,
is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>,
conn_bitmap: Option<RouteConnBitmap>,
) -> Result<SyncRouteInfoResponse, Error>;
}
type Error = SyncRouteInfoError;
// constructed with all infos synced from all peers.
#[derive(Debug)]
@ -299,7 +306,7 @@ impl SyncedRouteInfo {
for mut route_info in peer_infos.iter().map(Clone::clone) {
// time between peers may not be synchronized, so update last_update to local now.
// note only last_update with larger version will be updated to local saved peer info.
route_info.last_update = SystemTime::now();
route_info.last_update = Some(SystemTime::now().into());
self.peer_infos
.entry(route_info.peer_id)
@ -581,7 +588,7 @@ impl RouteTable {
let info = item.value();
if let Some(ipv4_addr) = info.ipv4_addr {
self.ipv4_peer_id_map.insert(ipv4_addr, *peer_id);
self.ipv4_peer_id_map.insert(ipv4_addr.into(), *peer_id);
}
for cidr in info.proxy_cidrs.iter() {
@ -996,7 +1003,8 @@ impl PeerRouteServiceImpl {
let now = SystemTime::now();
let mut to_remove = Vec::new();
for item in self.synced_route_info.peer_infos.iter() {
if let Ok(d) = now.duration_since(item.value().last_update) {
if let Ok(d) = now.duration_since(item.value().last_update.unwrap().try_into().unwrap())
{
if d > REMOVE_DEAD_PEER_INFO_AFTER {
to_remove.push(*item.key());
}
@ -1021,7 +1029,7 @@ impl PeerRouteServiceImpl {
let my_peer_id = self.my_peer_id;
let (peer_infos, conn_bitmap) = self.build_sync_request(&session);
tracing::info!("my_id {:?}, pper_id: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}",
tracing::info!("building sync_route request. my_id {:?}, pper_id: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}",
my_peer_id, dst_peer_id, peer_infos, conn_bitmap, self.synced_route_info, session);
if peer_infos.is_none()
@ -1035,33 +1043,60 @@ impl PeerRouteServiceImpl {
.need_sync_initiator_info
.store(false, Ordering::Relaxed);
let ret = peer_rpc
.do_client_rpc_scoped(SERVICE_ID, dst_peer_id, |c| async {
let client = RouteServiceClient::new(tarpc::client::Config::default(), c).spawn();
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
client
.sync_route_info(
rpc_ctx,
my_peer_id,
session.my_session_id.load(Ordering::Relaxed),
session.we_are_initiator.load(Ordering::Relaxed),
peer_infos.clone(),
conn_bitmap.clone(),
)
.await
})
let rpc_stub = peer_rpc
.rpc_client()
.scoped_client::<OspfRouteRpcClientFactory<BaseController>>(
self.my_peer_id,
dst_peer_id,
self.global_ctx.get_network_name(),
);
let mut ctrl = BaseController {};
ctrl.set_timeout_ms(3000);
let ret = rpc_stub
.sync_route_info(
ctrl,
SyncRouteInfoRequest {
my_peer_id,
my_session_id: session.my_session_id.load(Ordering::Relaxed),
is_initiator: session.we_are_initiator.load(Ordering::Relaxed),
peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }),
conn_bitmap: conn_bitmap.clone().map(Into::into),
},
)
.await;
match ret {
Ok(Ok(ret)) => {
if let Err(e) = &ret {
tracing::error!(
?ret,
?my_peer_id,
?dst_peer_id,
?e,
"sync_route_info failed"
);
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
} else {
let resp = ret.as_ref().unwrap();
if resp.error.is_some() {
let err = resp.error.unwrap();
if err == Error::DuplicatePeerId as i32 {
panic!("duplicate peer id");
} else {
tracing::error!(?ret, ?my_peer_id, ?dst_peer_id, "sync_route_info failed");
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
}
} else {
session.rpc_tx_count.fetch_add(1, Ordering::Relaxed);
session
.dst_is_initiator
.store(ret.is_initiator, Ordering::Relaxed);
.store(resp.is_initiator, Ordering::Relaxed);
session.update_dst_session_id(ret.session_id);
session.update_dst_session_id(resp.session_id);
if let Some(peer_infos) = &peer_infos {
session.update_dst_saved_peer_info_version(&peer_infos);
@ -1071,17 +1106,6 @@ impl PeerRouteServiceImpl {
session.update_dst_saved_conn_bitmap_version(&conn_bitmap);
}
}
Ok(Err(Error::DuplicatePeerId)) => {
panic!("duplicate peer id");
}
_ => {
tracing::error!(?ret, ?my_peer_id, ?dst_peer_id, "sync_route_info failed");
session
.need_sync_initiator_info
.store(true, Ordering::Relaxed);
}
}
return false;
}
@ -1103,59 +1127,37 @@ impl Debug for RouteSessionManager {
}
}
#[tarpc::server]
impl RouteService for RouteSessionManager {
#[async_trait::async_trait]
impl OspfRouteRpc for RouteSessionManager {
type Controller = BaseController;
async fn sync_route_info(
self,
_: tarpc::context::Context,
from_peer_id: PeerId,
from_session_id: SessionId,
is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>,
conn_bitmap: Option<RouteConnBitmap>,
) -> Result<SyncRouteInfoResponse, Error> {
let Some(service_impl) = self.service_impl.upgrade() else {
return Err(Error::Stopped);
};
&self,
_ctrl: BaseController,
request: SyncRouteInfoRequest,
) -> Result<SyncRouteInfoResponse, rpc_types::error::Error> {
let from_peer_id = request.my_peer_id;
let from_session_id = request.my_session_id;
let is_initiator = request.is_initiator;
let peer_infos = request.peer_infos.map(|x| x.items);
let conn_bitmap = request.conn_bitmap.map(Into::into);
let my_peer_id = service_impl.my_peer_id;
let session = self.get_or_start_session(from_peer_id)?;
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
session.update_dst_session_id(from_session_id);
if let Some(peer_infos) = &peer_infos {
service_impl.synced_route_info.update_peer_infos(
my_peer_id,
let ret = self
.do_sync_route_info(
from_peer_id,
from_session_id,
is_initiator,
peer_infos,
)?;
session.update_dst_saved_peer_info_version(peer_infos);
}
conn_bitmap,
)
.await;
if let Some(conn_bitmap) = &conn_bitmap {
service_impl.synced_route_info.update_conn_map(&conn_bitmap);
session.update_dst_saved_conn_bitmap_version(conn_bitmap);
}
service_impl.update_route_table_and_cached_local_conn_bitmap();
tracing::info!(
"sync_route_info: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}",
from_peer_id, is_initiator, peer_infos, conn_bitmap, service_impl.synced_route_info, session, service_impl.route_table);
session
.dst_is_initiator
.store(is_initiator, Ordering::Relaxed);
let is_initiator = session.we_are_initiator.load(Ordering::Relaxed);
let session_id = session.my_session_id.load(Ordering::Relaxed);
self.sync_now("sync_route_info");
Ok(SyncRouteInfoResponse {
is_initiator,
session_id,
Ok(match ret {
Ok(v) => v,
Err(e) => {
let mut resp = SyncRouteInfoResponse::default();
resp.error = Some(e as i32);
resp
}
})
}
}
@ -1366,6 +1368,60 @@ impl RouteSessionManager {
let ret = self.sync_now_broadcast.send(());
tracing::debug!(?ret, ?reason, "sync_now_broadcast.send");
}
// Core handler for the OspfRouteRpc `sync_route_info` call: merge the remote
// peer's route infos / conn bitmap into our synced state, refresh the route
// table, and report back our own session id and initiator flag.
async fn do_sync_route_info(
&self,
from_peer_id: PeerId,
from_session_id: SessionId,
is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>,
conn_bitmap: Option<RouteConnBitmap>,
) -> Result<SyncRouteInfoResponse, Error> {
// service_impl is held weakly; if it is gone the route is shutting down.
let Some(service_impl) = self.service_impl.upgrade() else {
return Err(Error::Stopped);
};
let my_peer_id = service_impl.my_peer_id;
// may create a new session for this peer; fails with DuplicatePeerId etc.
let session = self.get_or_start_session(from_peer_id)?;
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
session.update_dst_session_id(from_session_id);
if let Some(peer_infos) = &peer_infos {
service_impl.synced_route_info.update_peer_infos(
my_peer_id,
from_peer_id,
peer_infos,
)?;
// remember which versions the remote already sent us, to avoid re-sync
session.update_dst_saved_peer_info_version(peer_infos);
}
if let Some(conn_bitmap) = &conn_bitmap {
service_impl.synced_route_info.update_conn_map(&conn_bitmap);
session.update_dst_saved_conn_bitmap_version(conn_bitmap);
}
// recompute the route table now that synced info changed
service_impl.update_route_table_and_cached_local_conn_bitmap();
tracing::info!(
"handling sync_route_info rpc: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}",
from_peer_id, is_initiator, peer_infos, conn_bitmap, service_impl.synced_route_info, session, service_impl.route_table);
// record whether the remote side considers itself the initiator
session
.dst_is_initiator
.store(is_initiator, Ordering::Relaxed);
// shadowed on purpose: from here on, report OUR initiator flag back
let is_initiator = session.we_are_initiator.load(Ordering::Relaxed);
let session_id = session.my_session_id.load(Ordering::Relaxed);
// kick the sync loop so our updated state propagates promptly
self.sync_now("sync_route_info");
Ok(SyncRouteInfoResponse {
is_initiator,
session_id,
error: None,
})
}
}
pub struct PeerRoute {
@ -1415,7 +1471,7 @@ impl PeerRoute {
tokio::time::sleep(Duration::from_secs(60)).await;
service_impl.clear_expired_peer();
// TODO: use debug log level for this.
tracing::info!(?service_impl, "clear_expired_peer");
tracing::debug!(?service_impl, "clear_expired_peer");
}
}
@ -1453,8 +1509,10 @@ impl PeerRoute {
}
async fn start(&self) {
self.peer_rpc
.run_service(SERVICE_ID, RouteService::serve(self.session_mgr.clone()));
self.peer_rpc.rpc_server().registry().register(
OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
self.tasks
.lock()
@ -1479,6 +1537,15 @@ impl PeerRoute {
}
}
// Remove our OSPF rpc service from the per-network registry on teardown.
impl Drop for PeerRoute {
fn drop(&mut self) {
// NOTE(review): a fresh OspfRouteRpcServer is constructed here just to
// identify the service being unregistered — presumably matched by
// service type/network name rather than instance; confirm in registry impl.
self.peer_rpc.rpc_server().registry().unregister(
OspfRouteRpcServer::new(self.session_mgr.clone()),
&self.global_ctx.get_network_name(),
);
}
}
#[async_trait::async_trait]
impl Route for PeerRoute {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
@ -1507,7 +1574,7 @@ impl Route for PeerRoute {
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
}
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
async fn list_routes(&self) -> Vec<crate::proto::cli::Route> {
let route_table = &self.service_impl.route_table;
let mut routes = Vec::new();
for item in route_table.peer_infos.iter() {
@ -1517,7 +1584,7 @@ impl Route for PeerRoute {
let Some(next_hop_peer) = route_table.get_next_hop(*item.key()) else {
continue;
};
let mut route: crate::rpc::Route = item.value().clone().into();
let mut route: crate::proto::cli::Route = item.value().clone().into();
route.next_hop_peer_id = next_hop_peer.0;
route.cost = next_hop_peer.1;
routes.push(route);
@ -1567,7 +1634,7 @@ mod tests {
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
tests::connect_peer_manager,
},
rpc::NatType,
proto::common::NatType,
tunnel::common::tests::wait_for_condition,
};

View File

@ -1,753 +0,0 @@
use std::{
net::Ipv4Addr,
sync::{atomic::AtomicU32, Arc},
time::{Duration, Instant},
};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::route_trait::{Route, RouteInterfaceBox},
rpc::{NatType, StunInfo},
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::PeerPacketFilter;
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
const ROUTE_EXPIRED_SEC: u64 = 70;
type Version = u32;
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
// Derives can be passed through to the generated type:
// One peer's route entry as carried in RIP-style sync packets.
pub struct SyncPeerInfo {
// means next hop in route table.
pub peer_id: PeerId,
// hop count to reach this peer; 0 means "myself" (see new_self).
pub cost: u32,
// virtual ipv4 assigned to the peer, if any.
pub ipv4_addr: Option<Ipv4Addr>,
// stringified CIDRs this peer proxies, including its VPN portal CIDR.
pub proxy_cidrs: Vec<String>,
// peer hostname, if known.
pub hostname: Option<String>,
// raw NAT type value from the UDP stun probe (cast from the enum).
pub udp_stun_info: i8,
}
impl SyncPeerInfo {
    /// Build the route entry describing ourselves: cost 0, with our own
    /// ipv4 / hostname / stun info and all CIDRs we proxy (incl. vpn portal).
    pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
        let mut proxy_cidrs: Vec<String> = global_ctx
            .get_proxy_cidrs()
            .iter()
            .map(|c| c.to_string())
            .collect();
        if let Some(portal) = global_ctx.get_vpn_portal_cidr() {
            proxy_cidrs.push(portal.to_string());
        }
        SyncPeerInfo {
            peer_id: from_peer,
            cost: 0,
            ipv4_addr: global_ctx.get_ipv4(),
            proxy_cidrs,
            hostname: Some(global_ctx.get_hostname()),
            udp_stun_info: global_ctx
                .get_stun_info_collector()
                .get_stun_info()
                .udp_nat_type as i8,
        }
    }

    /// Produce a route-table entry for `from`'s data, reachable through
    /// `next_hop` at `cost`. (`self` is not read; only `from`'s fields are.)
    pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
        let mut entry = from.clone();
        entry.peer_id = next_hop;
        entry.cost = cost;
        entry
    }
}
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
// A full route-sync packet: our own entry plus every neighbor we know.
pub struct SyncPeer {
pub myself: SyncPeerInfo,
pub neighbors: Vec<SyncPeerInfo>,
// the route table version of myself
pub version: Version,
// the route table version of peer that we have received last time
pub peer_version: Option<Version>,
// if we do not have latest peer version, need_reply is true
pub need_reply: bool,
}
impl SyncPeer {
    /// Assemble a sync packet originating from `from_peer`.
    /// `_to_peer` is accepted for call-site symmetry but unused.
    pub fn new(
        from_peer: PeerId,
        _to_peer: PeerId,
        neighbors: Vec<SyncPeerInfo>,
        global_ctx: ArcGlobalCtx,
        version: Version,
        peer_version: Option<Version>,
        need_reply: bool,
    ) -> Self {
        let myself = SyncPeerInfo::new_self(from_peer, &global_ctx);
        Self {
            myself,
            neighbors,
            version,
            peer_version,
            need_reply,
        }
    }
}
#[derive(Debug)]
// Latest sync packet received from one remote peer, stamped for expiry checks.
struct SyncPeerFromRemote {
packet: SyncPeer,
last_update: std::time::Instant,
}
// peer_id -> most recent packet received from that peer.
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
#[derive(Debug)]
// Computed routing state: per-destination next-hop info plus reverse lookups.
struct RouteTable {
// dst peer id -> entry whose peer_id field holds the next hop towards dst.
route_info: DashMap<PeerId, SyncPeerInfo>,
// virtual ipv4 -> peer id owning that address.
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
// proxied subnet -> peer id proxying it.
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
impl RouteTable {
    /// Create an empty route table.
    fn new() -> Self {
        Self {
            route_info: DashMap::new(),
            ipv4_peer_id_map: DashMap::new(),
            cidr_peer_id_map: DashMap::new(),
        }
    }

    /// Replace our contents with a snapshot of `other`, map by map.
    fn copy_from(&self, other: &Self) {
        self.route_info.clear();
        other.route_info.iter().for_each(|e| {
            self.route_info.insert(*e.key(), e.value().clone());
        });

        self.ipv4_peer_id_map.clear();
        other.ipv4_peer_id_map.iter().for_each(|e| {
            self.ipv4_peer_id_map.insert(*e.key(), *e.value());
        });

        self.cidr_peer_id_map.clear();
        other.cidr_peer_id_map.iter().for_each(|e| {
            self.cidr_peer_id_map.insert(*e.key(), *e.value());
        });
    }
}
/// Monotonically increasing route-table version, shared across tasks via Arc.
#[derive(Debug, Clone)]
struct RouteVersion(Arc<AtomicU32>);

impl RouteVersion {
    fn new() -> Self {
        Self(Arc::new(AtomicU32::new(0)))
    }

    /// Read the current version.
    fn get(&self) -> Version {
        self.0.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Bump the version after any route-table or self-info change.
    fn inc(&self) {
        self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }
}
// RIP-style distance-vector route implementation.
pub struct BasicRoute {
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
// route interface for sending packets / listing peers; the background
// tasks unwrap it, so it must be populated before they run.
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
route_table: Arc<RouteTable>,
// latest packet from each remote peer; the route table is rebuilt from this.
sync_peer_from_remote: SyncPeerFromRemoteMap,
tasks: Mutex<JoinSet<()>>,
// wakes the sync loop early when the route table changes.
need_sync_notifier: Arc<tokio::sync::Notify>,
// bumped whenever our own info or route table changes.
version: RouteVersion,
myself: Arc<RwLock<SyncPeerInfo>>,
// per-peer (last sent version, peer version we saw, send instant) for rate limiting.
last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
}
impl BasicRoute {
/// Construct a BasicRoute with empty route state for `my_peer_id`.
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
    // build our own entry first, while we can still borrow global_ctx
    let myself = SyncPeerInfo::new_self(my_peer_id, &global_ctx);
    BasicRoute {
        my_peer_id,
        global_ctx,
        interface: Arc::new(Mutex::new(None)),
        route_table: Arc::new(RouteTable::new()),
        sync_peer_from_remote: Arc::new(DashMap::new()),
        tasks: Mutex::new(JoinSet::new()),
        need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
        version: RouteVersion::new(),
        myself: Arc::new(RwLock::new(myself)),
        last_send_time_map: Arc::new(DashMap::new()),
    }
}
/// Rebuild `route_table` from scratch out of every peer's latest sync packet.
fn update_route_table(
    my_id: PeerId,
    sync_peer_reqs: SyncPeerFromRemoteMap,
    route_table: Arc<RouteTable>,
) {
    tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
    // build into a fresh table, then swap contents in one pass
    let fresh = Arc::new(RouteTable::new());
    for entry in sync_peer_reqs.iter() {
        Self::update_route_table_with_req(my_id, &entry.value().packet, fresh.clone());
    }
    route_table.copy_from(&fresh);
}
/// Refresh our own SyncPeerInfo from global context.
/// Returns true when anything changed (caller bumps the route version).
async fn update_myself(
    my_peer_id: PeerId,
    myself: &Arc<RwLock<SyncPeerInfo>>,
    global_ctx: &ArcGlobalCtx,
) -> bool {
    let latest = SyncPeerInfo::new_self(my_peer_id, global_ctx);
    let changed = *myself.read().await != latest;
    if changed {
        *myself.write().await = latest;
    }
    changed
}
/// Merge one peer's sync packet into `route_table`.
///
/// The sender itself is reachable at cost 1; each of its neighbors is
/// reachable through it at the neighbor's advertised cost + 1. An existing
/// entry is only replaced by a strictly cheaper one.
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
    // Entries that grow this expensive almost certainly indicate a lost link.
    const MAX_ROUTE_COST: u32 = 6;

    // the sender is the gateway (next hop) for everything in this packet
    let peer_id = packet.myself.peer_id;
    let update = |cost: u32, peer_info: &SyncPeerInfo| {
        let node_id: PeerId = peer_info.peer_id;
        let ret = route_table
            .route_info
            .entry(node_id)
            .and_modify(|info| {
                // keep the cheaper route
                if info.cost > cost {
                    *info = info.clone_for_route_table(peer_id, cost, peer_info);
                }
            })
            .or_insert(peer_info.clone_for_route_table(peer_id, cost, peer_info))
            .value()
            .clone();

        if ret.cost > MAX_ROUTE_COST {
            tracing::error!(
                "cost too large: {}, may lost connection, remove it",
                ret.cost
            );
            route_table.route_info.remove(&node_id);
        }

        tracing::trace!(
            "update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
            node_id,
            peer_id,
            cost,
            &peer_info
        );

        if let Some(ipv4) = peer_info.ipv4_addr {
            route_table.ipv4_peer_id_map.insert(ipv4, node_id);
        }

        for cidr in peer_info.proxy_cidrs.iter() {
            // NOTE(review): proxy_cidrs comes from the remote peer; a
            // malformed CIDR string will panic here — consider skipping
            // invalid entries instead.
            let cidr: cidr::IpCidr = cidr.parse().unwrap();
            route_table.cidr_peer_id_map.insert(cidr, node_id);
        }
    };

    for neighbor in packet.neighbors.iter() {
        // never route to ourselves through someone else
        if neighbor.peer_id == my_id {
            continue;
        }
        update(neighbor.cost + 1, &neighbor);
        tracing::trace!("route info: {:?}", neighbor);
    }

    // add the sender peer to route info
    update(1, &packet.myself);

    tracing::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
/// Serialize a snapshot of our route table into a SyncPeer packet and send
/// it to `peer_id` over the route interface.
async fn send_sync_peer_request(
    interface: &RouteInterfaceBox,
    my_peer_id: PeerId,
    global_ctx: ArcGlobalCtx,
    peer_id: PeerId,
    route_table: Arc<RouteTable>,
    my_version: Version,
    peer_version: Option<Version>,
    need_reply: bool,
) -> Result<(), Error> {
    // Snapshot the route table; each entry is advertised with its own dst
    // peer id as next hop so the receiver re-derives hops relative to us.
    // (clone_for_route_table only reads its `from` arg — no extra clone needed.)
    let mut route_info_copy: Vec<SyncPeerInfo> =
        Vec::with_capacity(route_table.route_info.len());
    for item in route_table.route_info.iter() {
        let (k, v) = item.pair();
        route_info_copy.push(v.clone_for_route_table(*k, v.cost, v));
    }
    let msg = SyncPeer::new(
        my_peer_id,
        peer_id,
        route_info_copy,
        global_ctx,
        my_version,
        peer_version,
        need_reply,
    );

    // TODO: this may exceed the MTU of the tunnel
    interface
        .send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
        .await
}
// Spawn the background loop that pushes our route table to every connected
// peer, either every second (via timeout) or immediately when notified.
// Peers that already hold our latest version are re-sent at most once per
// SEND_ROUTE_PERIOD_SEC.
async fn sync_peer_periodically(&self) {
let route_table = self.route_table.clone();
let global_ctx = self.global_ctx.clone();
let my_peer_id = self.my_peer_id.clone();
let interface = self.interface.clone();
let notifier = self.need_sync_notifier.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let myself = self.myself.clone();
let version = self.version.clone();
let last_send_time_map = self.last_send_time_map.clone();
self.tasks.lock().await.spawn(
async move {
loop {
// our own info (ip, proxied cidrs, ...) may have changed; if so,
// bump the version so peers treat their copy as stale
if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table version when myself changed"
);
}
let lockd_interface = interface.lock().await;
let interface = lockd_interface.as_ref().unwrap();
// rebuilt each round so entries for vanished peers drop out
let last_send_time_map_new = DashMap::new();
let peers = interface.list_peers().await;
for peer in peers.iter() {
// default: pretend we last sent an hour ago, forcing a send
let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
// the version of OUR table that this peer acked last
let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
let peer_have_latest_version = my_version_peer_saved == Some(version.get());
// skip peers that are up to date and recently contacted
if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
last_send_time_map_new.insert(*peer, last_send_time);
continue;
}
tracing::trace!(
my_id = ?my_peer_id,
dst_peer_id = ?peer,
version = version.get(),
?my_version_peer_saved,
last_send_version = ?last_send_time.0,
last_send_peer_version = ?last_send_time.1,
last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
"need send route info"
);
let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
// ask for a reply only if the peer lacks our latest version
let ret = Self::send_sync_peer_request(
interface,
my_peer_id.clone(),
global_ctx.clone(),
*peer,
route_table.clone(),
version.get(),
peer_version_we_saved,
!peer_have_latest_version,
)
.await;
match &ret {
Ok(_) => {
tracing::trace!("send sync peer request to peer: {}", peer);
}
// expected while a peer is reconnecting; not an error
Err(Error::PeerNoConnectionError(_)) => {
tracing::trace!("peer {} no connection", peer);
}
Err(e) => {
tracing::error!(
"send sync peer request to peer: {} error: {:?}",
peer,
e
);
}
};
}
// swap in the freshly built send-time map
last_send_time_map.clear();
for item in last_send_time_map_new.iter() {
let (k, v) = item.pair();
last_send_time_map.insert(*k, *v);
}
// wait for either an explicit wake-up or the 1s tick
tokio::select! {
_ = notifier.notified() => {
tracing::trace!("sync peer request triggered by notifier");
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
tracing::trace!("sync peer request triggered by timeout");
}
}
}
}
.instrument(
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
),
);
}
/// Spawn the background loop that drops sync state from peers whose packets
/// expired or that are no longer connected, then rebuilds the route table
/// and wakes the sync loop so the change propagates.
async fn check_expired_sync_peer_from_remote(&self) {
    let route_table = self.route_table.clone();
    let my_peer_id = self.my_peer_id;
    let sync_peer_from_remote = self.sync_peer_from_remote.clone();
    let notifier = self.need_sync_notifier.clone();
    let interface = self.interface.clone();
    let version = self.version.clone();
    self.tasks.lock().await.spawn(async move {
        loop {
            let mut need_update_route = false;
            let now = std::time::Instant::now();
            let mut need_remove = Vec::new();
            let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
            for item in sync_peer_from_remote.iter() {
                let (k, v) = item.pair();
                // stale if the last packet is too old or the peer disconnected
                if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
                    || !connected_peers.contains(k)
                {
                    need_update_route = true;
                    // push, not insert(0, _): removal order is irrelevant and
                    // front-insertion is O(n) per element.
                    need_remove.push(*k);
                }
            }
            for k in need_remove.iter() {
                tracing::warn!("remove expired sync peer: {:?}", k);
                sync_peer_from_remote.remove(k);
            }
            if need_update_route {
                Self::update_route_table(
                    my_peer_id,
                    sync_peer_from_remote.clone(),
                    route_table.clone(),
                );
                version.inc();
                tracing::info!(
                    my_id = ?my_peer_id,
                    version = version.get(),
                    "update route table when check expired peer"
                );
                notifier.notify_one();
            }
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    });
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
    // Returns the peer that proxies the first CIDR containing this address,
    // or None when no proxied subnet covers it.
    let addr = std::net::IpAddr::V4(*ipv4);
    self.route_table
        .cidr_peer_id_map
        .iter()
        .find(|entry| entry.key().contains(&addr))
        .map(|entry| *entry.value())
}
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
let p = &packet;
let mut updated = true;
assert_eq!(packet.myself.peer_id, src_peer_id);
self.sync_peer_from_remote
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
updated = false;
} else {
v.packet = p.clone();
}
v.packet.version = p.version;
v.packet.peer_version = p.peer_version;
v.last_update = std::time::Instant::now();
})
.or_insert(SyncPeerFromRemote {
packet: p.clone(),
last_update: std::time::Instant::now(),
});
if updated {
Self::update_route_table(
self.my_peer_id.clone(),
self.sync_peer_from_remote.clone(),
self.route_table.clone(),
);
self.version.inc();
tracing::info!(
my_id = ?self.my_peer_id,
?p,
version = self.version.get(),
"update route table when receive route packet"
);
}
if packet.need_reply {
self.last_send_time_map
.entry(packet.myself.peer_id.into())
.and_modify(|v| {
const FAST_REPLY_DURATION: u64 =
SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
if v.0 != self.version.get() || v.1 != Some(p.version) {
v.2 = Instant::now() - Duration::from_secs(3600);
} else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
// do not send same version route info too frequently
v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
}
});
}
if updated || packet.need_reply {
self.need_sync_notifier.notify_one();
}
}
}
#[async_trait]
impl Route for BasicRoute {
    /// Installs the route interface and launches the periodic-sync and
    /// expiration background tasks. Always reports route id 1.
    async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
        self.interface.lock().await.replace(interface);
        self.sync_peer_periodically().await;
        self.check_expired_sync_peer_from_remote().await;
        Ok(1)
    }

    async fn close(&self) {}

    /// Looks up the next hop toward `dst_peer_id` in the route table.
    async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
        if let Some(info) = self.route_table.route_info.get(&dst_peer_id) {
            Some(info.peer_id.clone().into())
        } else {
            tracing::error!("no route info for dst_peer_id: {}", dst_peer_id);
            None
        }
    }

    /// Converts every route-table entry into its RPC representation.
    async fn list_routes(&self) -> Vec<crate::rpc::Route> {
        let to_rpc_route = |real_peer_id: PeerId, info: &SyncPeerInfo| {
            let mut stun_info = StunInfo::default();
            if let Ok(udp_nat_type) = NatType::try_from(info.udp_stun_info as i32) {
                stun_info.set_udp_nat_type(udp_nat_type);
            }
            let mut route = crate::rpc::Route::default();
            // Empty string stands for "no virtual ipv4 assigned".
            route.ipv4_addr = info
                .ipv4_addr
                .map(|addr| addr.to_string())
                .unwrap_or_default();
            route.peer_id = real_peer_id;
            route.next_hop_peer_id = info.peer_id;
            route.cost = info.cost as i32;
            route.proxy_cidrs = info.proxy_cidrs.clone();
            route.hostname = info.hostname.clone().unwrap_or_default();
            route.stun_info = Some(stun_info);
            route
        };
        self.route_table
            .route_info
            .iter()
            .map(|item| to_rpc_route(*item.key(), item.value()))
            .collect()
    }

    /// Resolves a virtual ipv4 to a peer: exact address match first, then
    /// proxied-CIDR lookup.
    async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
        if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
            return Some(*peer_id);
        }
        let by_proxy = self.get_peer_id_for_proxy(ipv4_addr);
        if by_proxy.is_none() {
            tracing::info!("no peer id for ipv4: {}", ipv4_addr);
        }
        by_proxy
    }
}
#[async_trait::async_trait]
impl PeerPacketFilter for BasicRoute {
    /// Consumes packets of type `Route` (feeding them to the route handler)
    /// and passes every other packet through untouched.
    async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
        let hdr = packet.peer_manager_header().unwrap();
        if hdr.packet_type != PacketType::Route as u8 {
            return Some(packet);
        }
        let from_peer = hdr.from_peer_id.get();
        let payload = packet.payload().to_vec();
        self.handle_route_packet(from_peer, payload.into()).await;
        None
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::{
        common::{global_ctx::tests::get_mock_global_ctx, PeerId},
        connector::udp_hole_punch::tests::replace_stun_info_collector,
        peers::{
            peer_manager::{PeerManager, RouteAlgoType},
            peer_rip_route::Version,
            tests::{connect_peer_manager, wait_route_appear},
        },
        rpc::NatType,
    };

    // Builds and starts a RIP-routed PeerManager with mocked global context
    // and a stubbed STUN collector; the packet channel receiver is dropped.
    async fn create_mock_pmgr() -> Arc<PeerManager> {
        let (s, _r) = tokio::sync::mpsc::channel(1000);
        let peer_mgr = Arc::new(PeerManager::new(
            RouteAlgoType::Rip,
            get_mock_global_ctx(),
            s,
        ));
        replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
        peer_mgr.run().await.unwrap();
        peer_mgr
    }

    // Exercises route propagation over a 3-node line topology (a - b - c):
    // routes must converge, version numbers must agree between peers, and
    // versions must stay stable (no churn) while the topology is unchanged.
    // NOTE(review): the sleeps and the `version.get() <= N` bounds encode
    // timing-sensitive expectations about how many sync rounds may run.
    #[tokio::test]
    async fn test_rip_route() {
        let peer_mgr_a = create_mock_pmgr().await;
        let peer_mgr_b = create_mock_pmgr().await;
        let peer_mgr_c = create_mock_pmgr().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
            .await
            .unwrap();
        // a reaches c only via b, so this also proves transit routing works.
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
            .await
            .unwrap();
        let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];

        // Give the periodic sync a few rounds to settle before checking.
        tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;

        // Asserts that every manager in `mgrs` has recorded `version` for
        // `peer_id`, and that the peer_version echoed back matches the
        // manager's own current version (i.e. the handshake is symmetric).
        let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
            for mgr in mgrs.iter() {
                tracing::warn!(
                    "check version: {:?}, {:?}, {:?}, {:?}",
                    version,
                    peer_id,
                    mgr,
                    mgr.get_basic_route().sync_peer_from_remote
                );
                assert_eq!(
                    version,
                    mgr.get_basic_route()
                        .sync_peer_from_remote
                        .get(&peer_id)
                        .unwrap()
                        .packet
                        .version,
                );
                assert_eq!(
                    mgr.get_basic_route()
                        .sync_peer_from_remote
                        .get(&peer_id)
                        .unwrap()
                        .packet
                        .peer_version
                        .unwrap(),
                    mgr.get_basic_route().version.get()
                );
            }
        };

        let check_sanity = || {
            // check peer version in other peer mgr are correct.
            // b is seen by both neighbors; a and c are seen only by b.
            check_version(
                peer_mgr_b.get_basic_route().version.get(),
                peer_mgr_b.my_peer_id(),
                &vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
            );
            check_version(
                peer_mgr_a.get_basic_route().version.get(),
                peer_mgr_a.my_peer_id(),
                &vec![peer_mgr_b.clone()],
            );
            check_version(
                peer_mgr_c.get_basic_route().version.get(),
                peer_mgr_c.my_peer_id(),
                &vec![peer_mgr_b.clone()],
            );
        };
        check_sanity();

        // With a stable topology, versions must not change over time —
        // otherwise the route sync is churning needlessly.
        let versions = mgrs
            .iter()
            .map(|x| x.get_basic_route().version.get())
            .collect::<Vec<_>>();
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let versions2 = mgrs
            .iter()
            .map(|x| x.get_basic_route().version.get())
            .collect::<Vec<_>>();
        assert_eq!(versions, versions2);
        check_sanity();

        // Upper bounds on version counters: b has two neighbors so it may
        // bump more often than the edge nodes a and c.
        assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
        assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
        assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
    }
}

View File

@ -1,27 +1,10 @@
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
time::Instant,
};
use std::sync::Arc;
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use prost::Message;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tracing::Instrument;
use futures::StreamExt;
use crate::{
common::{error::Error, PeerId},
rpc::TaRpcPacket,
proto::rpc_impl,
tunnel::packet_def::{PacketType, ZCPacket},
};
@ -38,33 +21,11 @@ pub trait PeerRpcManagerTransport: Send + Sync + 'static {
async fn recv(&self) -> Result<ZCPacket, Error>;
}
type PacketSender = UnboundedSender<ZCPacket>;
struct PeerRpcEndPoint {
peer_id: PeerId,
packet_sender: PacketSender,
create_time: AtomicCell<Instant>,
finished: Arc<AtomicBool>,
tasks: JoinSet<()>,
}
type PeerRpcEndPointCreator =
Box<dyn Fn(PeerId, PeerRpcTransactId) -> PeerRpcEndPoint + Send + Sync + 'static>;
#[derive(Hash, Eq, PartialEq, Clone)]
struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
// handle rpc request from one peer
pub struct PeerRpcManager {
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
tasks: JoinSet<()>,
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
peer_rpc_endpoints: Arc<DashMap<PeerRpcClientCtxKey, PeerRpcEndPoint>>,
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
transact_id: AtomicU32,
rpc_client: rpc_impl::client::Client,
rpc_server: rpc_impl::server::Server,
}
impl std::fmt::Debug for PeerRpcManager {
@ -75,293 +36,55 @@ impl std::fmt::Debug for PeerRpcManager {
}
}
struct PacketMerger {
first_piece: Option<TaRpcPacket>,
pieces: Vec<TaRpcPacket>,
}
impl PacketMerger {
fn new() -> Self {
Self {
first_piece: None,
pieces: Vec::new(),
}
}
fn try_merge_pieces(&self) -> Option<TaRpcPacket> {
if self.first_piece.is_none() || self.pieces.is_empty() {
return None;
}
for p in &self.pieces {
// some piece is missing
if p.total_pieces == 0 {
return None;
}
}
// all pieces are received
let mut content = Vec::new();
for p in &self.pieces {
content.extend_from_slice(&p.content);
}
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
tmpl_packet.total_pieces = 1;
tmpl_packet.piece_idx = 0;
tmpl_packet.content = content;
Some(tmpl_packet)
}
fn feed(
&mut self,
packet: ZCPacket,
expected_tid: Option<PeerRpcTransactId>,
) -> Result<Option<TaRpcPacket>, Error> {
let payload = packet.payload();
let rpc_packet =
TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))?;
if expected_tid.is_some() && rpc_packet.transact_id != expected_tid.unwrap() {
return Ok(None);
}
let total_pieces = rpc_packet.total_pieces;
let piece_idx = rpc_packet.piece_idx;
// for compatibility with old version
if total_pieces == 0 && piece_idx == 0 {
return Ok(Some(rpc_packet));
}
if total_pieces > 100 || total_pieces == 0 {
return Err(Error::MessageDecodeError(format!(
"total_pieces is invalid: {}",
total_pieces
)));
}
if piece_idx >= total_pieces {
return Err(Error::MessageDecodeError(
"piece_idx >= total_pieces".to_owned(),
));
}
if self.first_piece.is_none()
|| self.first_piece.as_ref().unwrap().transact_id != rpc_packet.transact_id
|| self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer
{
self.first_piece = Some(rpc_packet.clone());
self.pieces.clear();
}
self.pieces
.resize(total_pieces as usize, Default::default());
self.pieces[piece_idx as usize] = rpc_packet;
Ok(self.try_merge_pieces())
}
}
impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self {
service_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
tspt: Arc::new(Box::new(tspt)),
service_registry: Arc::new(DashMap::new()),
peer_rpc_endpoints: Arc::new(DashMap::new()),
client_resp_receivers: Arc::new(DashMap::new()),
transact_id: AtomicU32::new(0),
rpc_client: rpc_impl::client::Client::new(),
rpc_server: rpc_impl::server::Server::new(),
}
}
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
where
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Resp:
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Fut: Send + 'static,
{
let tspt = self.tspt.clone();
let creator = Box::new(move |peer_id: PeerId, transact_id: PeerRpcTransactId| {
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let finished = Arc::new(AtomicBool::new(false));
let my_peer_id_clone = tspt.my_peer_id();
let peer_id_clone = peer_id.clone();
let o = server.execute(s.clone());
tasks.spawn(o);
let tspt = tspt.clone();
let finished_clone = finished.clone();
tasks.spawn(async move {
let mut packet_merger = PacketMerger::new();
loop {
tokio::select! {
Some(resp) = client_transport.next() => {
tracing::debug!(resp = ?resp, ?transact_id, ?peer_id, "server recv packet from service provider");
if resp.is_err() {
tracing::warn!(err = ?resp.err(),
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
continue;
}
let resp = resp.unwrap();
let serialized_resp = postcard::to_allocvec(&resp);
if serialized_resp.is_err() {
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
continue;
}
let msgs = Self::build_rpc_packet(
tspt.my_peer_id(),
peer_id,
service_id,
transact_id,
false,
serialized_resp.as_ref().unwrap(),
);
for msg in msgs {
if let Err(e) = tspt.send(msg, peer_id).await {
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
break;
}
}
finished_clone.store(true, Ordering::Relaxed);
}
Some(packet) = packet_receiver.recv() => {
tracing::trace!("recv packet from peer, packet: {:?}", packet);
let info = match packet_merger.feed(packet, None) {
Err(e) => {
tracing::error!(error = ?e, "feed packet to merger failed");
continue;
},
Ok(None) => {
continue;
},
Ok(Some(info)) => {
info
}
};
assert_eq!(info.service_id, service_id);
assert_eq!(info.from_peer, peer_id);
assert_eq!(info.transact_id, transact_id);
let decoded_ret = postcard::from_bytes(&info.content.as_slice());
if let Err(e) = decoded_ret {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
if let Err(e) = client_transport.send(decoded).await {
tracing::error!(error = ?e, "send to req to client transport failed");
}
}
else => {
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
}
}
}
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
tracing::info!(
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
peer_id,
service_id
);
return PeerRpcEndPoint {
peer_id,
packet_sender,
create_time: AtomicCell::new(Instant::now()),
finished,
tasks,
};
// let resp = client_transport.next().await;
});
if let Some(_) = self.service_registry.insert(service_id, creator) {
panic!(
"[PEER RPC MGR] service {} is already registered",
service_id
);
}
tracing::info!(
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
service_id,
self.tspt.my_peer_id()
)
}
fn parse_rpc_packet(packet: &ZCPacket) -> Result<TaRpcPacket, Error> {
let payload = packet.payload();
TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))
}
fn build_rpc_packet(
from_peer: PeerId,
to_peer: PeerId,
service_id: PeerRpcServiceId,
transact_id: PeerRpcTransactId,
is_req: bool,
content: &Vec<u8>,
) -> Vec<ZCPacket> {
let mut ret = Vec::new();
let content_mtu = RPC_PACKET_CONTENT_MTU;
let total_pieces = (content.len() + content_mtu - 1) / content_mtu;
let mut cur_offset = 0;
while cur_offset < content.len() {
let mut cur_len = content_mtu;
if cur_offset + cur_len > content.len() {
cur_len = content.len() - cur_offset;
}
let mut cur_content = Vec::new();
cur_content.extend_from_slice(&content[cur_offset..cur_offset + cur_len]);
let cur_packet = TaRpcPacket {
from_peer,
to_peer,
service_id,
transact_id,
is_req,
total_pieces: total_pieces as u32,
piece_idx: (cur_offset / content_mtu) as u32,
content: cur_content,
};
cur_offset += cur_len;
let mut buf = Vec::new();
cur_packet.encode(&mut buf).unwrap();
let mut zc_packet = ZCPacket::new_with_payload(&buf);
zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8);
ret.push(zc_packet);
}
ret
}
pub fn run(&self) {
self.rpc_client.run();
self.rpc_server.run();
let (server_tx, mut server_rx) = (
self.rpc_server.get_transport_sink(),
self.rpc_server.get_transport_stream(),
);
let (client_tx, mut client_rx) = (
self.rpc_client.get_transport_sink(),
self.rpc_client.get_transport_stream(),
);
let tspt = self.tspt.clone();
tokio::spawn(async move {
loop {
let packet = tokio::select! {
Some(Ok(packet)) = server_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from server");
packet
}
Some(Ok(packet)) = client_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from client");
packet
}
else => {
tracing::warn!("rpc transport read aborted, exiting");
break;
}
};
let dst_peer_id = packet.peer_manager_header().unwrap().to_peer_id.into();
if let Err(e) = tspt.send(packet, dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
}
}
});
let tspt = self.tspt.clone();
let service_registry = self.service_registry.clone();
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
let client_resp_receivers = self.client_resp_receivers.clone();
tokio::spawn(async move {
loop {
let Ok(o) = tspt.recv().await else {
@ -369,176 +92,24 @@ impl PeerRpcManager {
break;
};
let info = Self::parse_rpc_packet(&o).unwrap();
tracing::debug!(?info, "recv rpc packet from peer");
if info.is_req {
if !service_registry.contains_key(&info.service_id) {
tracing::warn!(
"service {} not found, my_node_id: {}",
info.service_id,
tspt.my_peer_id()
);
continue;
}
let endpoint = peer_rpc_endpoints
.entry(PeerRpcClientCtxKey(
info.from_peer,
info.service_id,
info.transact_id,
))
.or_insert_with(|| {
service_registry.get(&info.service_id).unwrap()(
info.from_peer,
info.transact_id,
)
});
endpoint.packet_sender.send(o).unwrap();
} else {
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
info.from_peer,
info.service_id,
info.transact_id,
)) {
tracing::trace!("recv resp: {:?}", info);
if let Err(e) = a.send(o) {
tracing::error!(error = ?e, "send resp to client failed");
}
} else {
tracing::warn!("client resp receiver not found, info: {:?}", info);
}
if o.peer_manager_header().unwrap().packet_type == PacketType::RpcReq as u8 {
server_tx.send(o).await.unwrap();
continue;
} else if o.peer_manager_header().unwrap().packet_type == PacketType::RpcResp as u8
{
client_tx.send(o).await.unwrap();
continue;
}
}
});
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
peer_rpc_endpoints.retain(|_, v| {
v.create_time.load().elapsed().as_secs() < 30
&& !v.finished.load(Ordering::Relaxed)
});
}
});
}
#[tracing::instrument(skip(f))]
pub async fn do_client_rpc_scoped<Resp, Req, RpcRet, Fut>(
&self,
service_id: PeerRpcServiceId,
dst_peer_id: PeerId,
f: impl FnOnce(UnboundedChannel<Resp, Req>) -> Fut,
) -> RpcRet
where
Resp: serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Send
+ Sync
+ std::fmt::Debug
+ 'static,
Req: serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Send
+ Sync
+ std::fmt::Debug
+ 'static,
Fut: std::future::Future<Output = RpcRet>,
{
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
pub fn rpc_client(&self) -> &rpc_impl::client::Client {
&self.rpc_client
}
let (client_transport, server_transport) =
tarpc::transport::channel::unbounded::<Resp, Req>();
let (mut server_s, mut server_r) = server_transport.split();
let transact_id = self
.transact_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tspt = self.tspt.clone();
tasks.spawn(async move {
while let Some(a) = server_r.next().await {
if a.is_err() {
tracing::error!(error = ?a.err(), "channel error");
continue;
}
let req = postcard::to_allocvec(&a.unwrap());
if req.is_err() {
tracing::error!(error = ?req.err(), "bincode serialize failed");
continue;
}
let packets = Self::build_rpc_packet(
tspt.my_peer_id(),
dst_peer_id,
service_id,
transact_id,
true,
req.as_ref().unwrap(),
);
tracing::debug!(?packets, ?req, ?transact_id, "client send rpc packet to peer");
for packet in packets {
if let Err(e) = tspt.send(packet, dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
break;
}
}
}
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
});
tasks.spawn(async move {
let mut packet_merger = PacketMerger::new();
while let Some(packet) = packet_receiver.recv().await {
tracing::trace!("tunnel recv: {:?}", packet);
let info = match packet_merger.feed(packet, Some(transact_id)) {
Err(e) => {
tracing::error!(error = ?e, "feed packet to merger failed");
continue;
}
Ok(None) => {
continue;
}
Ok(Some(info)) => info,
};
let decoded = postcard::from_bytes(&info.content.as_slice());
tracing::debug!(?info, ?decoded, "client recv rpc packet from peer");
assert_eq!(info.transact_id, transact_id);
if let Err(e) = decoded {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
if let Err(e) = server_s.send(decoded.unwrap()).await {
tracing::error!(error = ?e, "send to rpc server channel failed");
}
}
tracing::warn!("[PEER RPC MGR] server packet read aborted");
});
let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
let _insert_ret = self
.client_resp_receivers
.insert(key.clone(), packet_sender);
let ret = f(client_transport).await;
self.client_resp_receivers.remove(&key);
ret
pub fn rpc_server(&self) -> &rpc_impl::server::Server {
&self.rpc_server
}
pub fn my_peer_id(&self) -> PeerId {
@ -548,7 +119,7 @@ impl PeerRpcManager {
#[cfg(test)]
pub mod tests {
use std::{pin::Pin, sync::Arc, time::Duration};
use std::{pin::Pin, sync::Arc};
use futures::{SinkExt, StreamExt};
use tokio::sync::Mutex;
@ -559,31 +130,18 @@ pub mod tests {
peer_rpc::PeerRpcManager,
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
},
proto::{
rpc_impl::RpcController,
tests::{GreetingClientFactory, GreetingServer, GreetingService, SayHelloRequest},
},
tunnel::{
common::tests::wait_for_condition, packet_def::ZCPacket, ring::create_ring_tunnel_pair,
Tunnel, ZCPacketSink, ZCPacketStream,
packet_def::ZCPacket, ring::create_ring_tunnel_pair, Tunnel,
ZCPacketSink, ZCPacketStream,
},
};
use super::PeerRpcManagerTransport;
#[tarpc::service]
pub trait TestRpcService {
async fn hello(s: String) -> String;
}
#[derive(Clone)]
pub struct MockService {
pub prefix: String,
}
#[tarpc::server]
impl TestRpcService for MockService {
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
format!("{} {}", self.prefix, s)
}
}
fn random_string(len: usize) -> String {
use rand::distributions::Alphanumeric;
use rand::Rng;
@ -595,6 +153,16 @@ pub mod tests {
String::from_utf8(s).unwrap()
}
pub fn register_service(rpc_mgr: &PeerRpcManager, domain: &str, delay_ms: u64, prefix: &str) {
rpc_mgr.rpc_server().registry().register(
GreetingServer::new(GreetingService {
delay_ms,
prefix: prefix.to_string(),
}),
domain,
);
}
#[tokio::test]
async fn peer_rpc_basic_test() {
struct MockTransport {
@ -630,10 +198,7 @@ pub mod tests {
my_peer_id: new_peer_id(),
});
server_rpc_mgr.run();
let s = MockService {
prefix: "hello".to_owned(),
};
server_rpc_mgr.run_service(1, s.serve());
register_service(&server_rpc_mgr, "test", 0, "Hello");
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
sink: Arc::new(Mutex::new(stsr)),
@ -642,35 +207,27 @@ pub mod tests {
});
client_rpc_mgr.run();
let stub = client_rpc_mgr
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
let msg = random_string(8192);
let ret = client_rpc_mgr
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), format!("hello {}", msg));
assert_eq!(ret.greeting, format!("Hello {}!", msg));
let msg = random_string(10);
let ret = client_rpc_mgr
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), format!("hello {}", msg));
wait_for_condition(
|| async { server_rpc_mgr.peer_rpc_endpoints.is_empty() },
Duration::from_secs(10),
)
.await;
assert_eq!(ret.greeting, format!("Hello {}!", msg));
}
#[tokio::test]
@ -680,6 +237,7 @@ pub mod tests {
let peer_mgr_c = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
.await
.unwrap();
@ -699,51 +257,42 @@ pub mod tests {
peer_mgr_b.my_peer_id()
);
let s = MockService {
prefix: "hello".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test", 0, "Hello");
let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a
let stub = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg));
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_a.my_peer_id(),
peer_mgr_b.my_peer_id(),
"test".to_string(),
);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
// call again
let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg));
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
let msg = random_string(16 * 1024);
let ip_list = peer_mgr_c
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.unwrap(), format!("hello {}", msg));
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
}
#[tokio::test]
async fn test_multi_service_with_peer_manager() {
async fn test_multi_domain_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
@ -757,42 +306,37 @@ pub mod tests {
peer_mgr_b.my_peer_id()
);
let s = MockService {
prefix: "hello_a".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let b = MockService {
prefix: "hello_b".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test1", 0, "Hello");
register_service(&peer_mgr_b.get_peer_rpc_mgr(), "test2", 20000, "Hello2");
let stub1 = peer_mgr_a
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_a.my_peer_id(),
peer_mgr_b.my_peer_id(),
"test1".to_string(),
);
let stub2 = peer_mgr_a
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<GreetingClientFactory<RpcController>>(
peer_mgr_a.my_peer_id(),
peer_mgr_b.my_peer_id(),
"test2".to_string(),
);
let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
let ret = stub1
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
let ret = stub2
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.await;
assert_eq!(ip_list.unwrap(), format!("hello_a {}", msg));
let msg = random_string(16 * 1024);
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), msg.clone()).await;
ret
})
.await;
assert_eq!(ip_list.unwrap(), format!("hello_b {}", msg));
wait_for_condition(
|| async { peer_mgr_b.get_peer_rpc_mgr().peer_rpc_endpoints.is_empty() },
Duration::from_secs(10),
)
.await;
assert!(ret.is_err() && ret.unwrap_err().to_string().contains("Timeout"));
}
}

View File

@ -1,9 +1,6 @@
use std::{net::Ipv4Addr, sync::Arc};
use async_trait::async_trait;
use tokio_util::bytes::Bytes;
use crate::common::{error::Error, PeerId};
use crate::common::PeerId;
#[derive(Clone, Debug)]
pub enum NextHopPolicy {
@ -17,15 +14,9 @@ impl Default for NextHopPolicy {
}
}
#[async_trait]
#[async_trait::async_trait]
pub trait RouteInterface {
async fn list_peers(&self) -> Vec<PeerId>;
async fn send_route_packet(
&self,
msg: Bytes,
route_id: u8,
dst_peer_id: PeerId,
) -> Result<(), Error>;
fn my_peer_id(&self) -> PeerId;
}
@ -56,7 +47,7 @@ impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {}
pub type RouteCostCalculator = Box<dyn RouteCostCalculatorInterface>;
#[async_trait]
#[async_trait::async_trait]
#[auto_impl::auto_impl(Box, Arc)]
pub trait Route {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
@ -71,7 +62,7 @@ pub trait Route {
self.get_next_hop(peer_id).await
}
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
async fn list_routes(&self) -> Vec<crate::proto::cli::Route>;
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
None

View File

@ -1,14 +1,17 @@
use std::sync::Arc;
use crate::rpc::{
cli::PeerInfo, peer_manage_rpc_server::PeerManageRpc, DumpRouteRequest, DumpRouteResponse,
ListForeignNetworkRequest, ListForeignNetworkResponse, ListPeerRequest, ListPeerResponse,
ListRouteRequest, ListRouteResponse, ShowNodeInfoRequest, ShowNodeInfoResponse,
use crate::proto::{
cli::{
DumpRouteRequest, DumpRouteResponse, ListForeignNetworkRequest, ListForeignNetworkResponse,
ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse, PeerInfo,
PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse,
},
rpc_types::{self, controller::BaseController},
};
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager;
#[derive(Clone)]
pub struct PeerManagerRpcService {
peer_manager: Arc<PeerManager>,
}
@ -36,12 +39,14 @@ impl PeerManagerRpcService {
}
}
#[tonic::async_trait]
#[async_trait::async_trait]
impl PeerManageRpc for PeerManagerRpcService {
type Controller = BaseController;
async fn list_peer(
&self,
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListPeerResponse>, Status> {
_: BaseController,
_request: ListPeerRequest, // Accept request of type HelloRequest
) -> Result<ListPeerResponse, rpc_types::error::Error> {
let mut reply = ListPeerResponse::default();
let peers = self.list_peers().await;
@ -49,45 +54,49 @@ impl PeerManageRpc for PeerManagerRpcService {
reply.peer_infos.push(peer);
}
Ok(Response::new(reply))
Ok(reply)
}
async fn list_route(
&self,
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListRouteResponse>, Status> {
_: BaseController,
_request: ListRouteRequest, // Accept request of type HelloRequest
) -> Result<ListRouteResponse, rpc_types::error::Error> {
let mut reply = ListRouteResponse::default();
reply.routes = self.peer_manager.list_routes().await;
Ok(Response::new(reply))
Ok(reply)
}
async fn dump_route(
&self,
_request: Request<DumpRouteRequest>, // Accept request of type HelloRequest
) -> Result<Response<DumpRouteResponse>, Status> {
_: BaseController,
_request: DumpRouteRequest, // Accept request of type HelloRequest
) -> Result<DumpRouteResponse, rpc_types::error::Error> {
let mut reply = DumpRouteResponse::default();
reply.result = self.peer_manager.dump_route().await;
Ok(Response::new(reply))
Ok(reply)
}
async fn list_foreign_network(
&self,
_request: Request<ListForeignNetworkRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListForeignNetworkResponse>, Status> {
_: BaseController,
_request: ListForeignNetworkRequest, // Accept request of type HelloRequest
) -> Result<ListForeignNetworkResponse, rpc_types::error::Error> {
let reply = self
.peer_manager
.get_foreign_network_manager()
.list_foreign_networks()
.await;
Ok(Response::new(reply))
Ok(reply)
}
async fn show_node_info(
&self,
_request: Request<ShowNodeInfoRequest>, // Accept request of type HelloRequest
) -> Result<Response<ShowNodeInfoResponse>, Status> {
Ok(Response::new(ShowNodeInfoResponse {
_: BaseController,
_request: ShowNodeInfoRequest, // Accept request of type HelloRequest
) -> Result<ShowNodeInfoResponse, rpc_types::error::Error> {
Ok(ShowNodeInfoResponse {
node_info: Some(self.peer_manager.get_my_info()),
}))
})
}
}

View File

@ -1,4 +1,7 @@
syntax = "proto3";
import "common.proto";
package cli;
message Status {
@ -16,18 +19,12 @@ message PeerConnStats {
uint64 latency_us = 5;
}
message TunnelInfo {
string tunnel_type = 1;
string local_addr = 2;
string remote_addr = 3;
}
message PeerConnInfo {
string conn_id = 1;
uint32 my_peer_id = 2;
uint32 peer_id = 3;
repeated string features = 4;
TunnelInfo tunnel = 5;
common.TunnelInfo tunnel = 5;
PeerConnStats stats = 6;
float loss_rate = 7;
bool is_client = 8;
@ -46,27 +43,6 @@ message ListPeerResponse {
NodeInfo my_info = 2;
}
enum NatType {
// has NAT; but own a single public IP, port is not changed
Unknown = 0;
OpenInternet = 1;
NoPAT = 2;
FullCone = 3;
Restricted = 4;
PortRestricted = 5;
Symmetric = 6;
SymUdpFirewall = 7;
}
message StunInfo {
NatType udp_nat_type = 1;
NatType tcp_nat_type = 2;
int64 last_update_time = 3;
repeated string public_ip = 4;
uint32 min_port = 5;
uint32 max_port = 6;
}
message Route {
uint32 peer_id = 1;
string ipv4_addr = 2;
@ -74,7 +50,7 @@ message Route {
int32 cost = 4;
repeated string proxy_cidrs = 5;
string hostname = 6;
StunInfo stun_info = 7;
common.StunInfo stun_info = 7;
string inst_id = 8;
string version = 9;
}
@ -84,7 +60,7 @@ message NodeInfo {
string ipv4_addr = 2;
repeated string proxy_cidrs = 3;
string hostname = 4;
StunInfo stun_info = 5;
common.StunInfo stun_info = 5;
string inst_id = 6;
repeated string listeners = 7;
string config = 8;
@ -127,7 +103,7 @@ enum ConnectorStatus {
}
message Connector {
string url = 1;
common.Url url = 1;
ConnectorStatus status = 2;
}
@ -142,7 +118,7 @@ enum ConnectorManageAction {
message ManageConnectorRequest {
ConnectorManageAction action = 1;
string url = 2;
common.Url url = 2;
}
message ManageConnectorResponse {}
@ -152,23 +128,6 @@ service ConnectorManageRpc {
rpc ManageConnector(ManageConnectorRequest) returns (ManageConnectorResponse);
}
message DirectConnectedPeerInfo { int32 latency_ms = 1; }
message PeerInfoForGlobalMap {
map<uint32, DirectConnectedPeerInfo> direct_peers = 1;
}
message GetGlobalPeerMapRequest {}
message GetGlobalPeerMapResponse {
map<uint32, PeerInfoForGlobalMap> global_peer_map = 1;
}
service PeerCenterRpc {
rpc GetGlobalPeerMap(GetGlobalPeerMapRequest)
returns (GetGlobalPeerMapResponse);
}
message VpnPortalInfo {
string vpn_type = 1;
string client_config = 2;
@ -182,24 +141,3 @@ service VpnPortalRpc {
rpc GetVpnPortalInfo(GetVpnPortalInfoRequest)
returns (GetVpnPortalInfoResponse);
}
message HandshakeRequest {
uint32 magic = 1;
uint32 my_peer_id = 2;
uint32 version = 3;
repeated string features = 4;
string network_name = 5;
bytes network_secret_digrest = 6;
}
message TaRpcPacket {
uint32 from_peer = 1;
uint32 to_peer = 2;
uint32 service_id = 3;
uint32 transact_id = 4;
bool is_req = 5;
bytes content = 6;
uint32 total_pieces = 7;
uint32 piece_idx = 8;
}

View File

@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/cli.rs"));

View File

@ -0,0 +1,92 @@
syntax = "proto3";

import "error.proto";

package common;

// Identifies a single RPC method: which domain/service it belongs to and the
// 1-based index of the method inside that service.
message RpcDescriptor {
  // allow same service registered multiple times in different domain
  string domain_name = 1;
  string proto_name = 2;
  string service_name = 3;
  uint32 method_index = 4;
}

// Serialized request body plus the descriptor and the caller's timeout.
message RpcRequest {
  RpcDescriptor descriptor = 1;
  bytes request = 2;
  int32 timeout_ms = 3;
}

// Serialized response body or an error, with server-side execution time.
message RpcResponse {
  bytes response = 1;
  error.Error error = 2;
  uint64 runtime_us = 3;
}

// One (possibly partial) RPC message on the wire. Large bodies are split into
// `total_pieces` fragments, ordered by `piece_idx`, and reassembled on the
// receiving side.
message RpcPacket {
  uint32 from_peer = 1;
  uint32 to_peer = 2;
  int64 transaction_id = 3;
  RpcDescriptor descriptor = 4;
  bytes body = 5;
  bool is_request = 6;
  uint32 total_pieces = 7;
  uint32 piece_idx = 8;
  int32 trace_id = 9;
}

// 128-bit UUID split into two words; `high` holds the most-significant bits.
message UUID {
  uint64 high = 1;
  uint64 low = 2;
}

// NAT classification detected via STUN probing.
enum NatType {
  // has NAT; but own a single public IP, port is not changed
  // NOTE(review): the comment above appears to describe NoPAT rather than
  // Unknown — confirm against the STUN detection code.
  Unknown = 0;
  OpenInternet = 1;
  NoPAT = 2;
  FullCone = 3;
  Restricted = 4;
  PortRestricted = 5;
  Symmetric = 6;
  SymUdpFirewall = 7;
}

// IPv4 address as a big-endian u32.
message Ipv4Addr { uint32 addr = 1; }

// IPv6 address split into two words.
// NOTE(review): the Rust conversions store the LEADING octets in `low` — see
// the `From` impls in proto/common.rs before relying on the field names.
message Ipv6Addr {
  uint64 high = 1;
  uint64 low = 2;
}

// URL kept in its serialized string form.
message Url { string url = 1; }

// Socket address: one of the IP forms plus a port (expected 0..=65535).
message SocketAddr {
  oneof ip {
    Ipv4Addr ipv4 = 1;
    Ipv6Addr ipv6 = 2;
  };
  uint32 port = 3;
}

// Endpoint pair describing an established tunnel.
message TunnelInfo {
  string tunnel_type = 1;
  common.Url local_addr = 2;
  common.Url remote_addr = 3;
}

// STUN probing result for this node.
message StunInfo {
  NatType udp_nat_type = 1;
  NatType tcp_nat_type = 2;
  int64 last_update_time = 3;
  repeated string public_ip = 4;
  uint32 min_port = 5;
  uint32 max_port = 6;
}

View File

@ -0,0 +1,131 @@
use std::{fmt::Display, str::FromStr};
include!(concat!(env!("OUT_DIR"), "/common.rs"));
impl From<uuid::Uuid> for Uuid {
    /// Split a standard UUID into the two 64-bit words stored by the proto
    /// message (`high` carries the most-significant bits).
    fn from(value: uuid::Uuid) -> Self {
        let (high, low) = value.as_u64_pair();
        Self { high, low }
    }
}
impl From<Uuid> for uuid::Uuid {
fn from(uuid: Uuid) -> Self {
uuid::Uuid::from_u64_pair(uuid.high, uuid.low)
}
}
impl Display for Uuid {
    /// Render in the canonical hyphenated UUID form by delegating to
    /// `uuid::Uuid`'s formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&uuid::Uuid::from(self.clone()), f)
    }
}
impl From<std::net::Ipv4Addr> for Ipv4Addr {
    /// Store the address as its big-endian `u32` representation.
    fn from(value: std::net::Ipv4Addr) -> Self {
        // std's `impl From<Ipv4Addr> for u32` yields the big-endian integer,
        // identical to `u32::from_be_bytes(value.octets())`.
        Self { addr: value.into() }
    }
}
impl From<Ipv4Addr> for std::net::Ipv4Addr {
    /// Interpret the stored `u32` as a big-endian IPv4 address.
    fn from(value: Ipv4Addr) -> Self {
        value.addr.into()
    }
}
impl ToString for Ipv4Addr {
fn to_string(&self) -> String {
std::net::Ipv4Addr::from(self.addr).to_string()
}
}
impl From<std::net::Ipv6Addr> for Ipv6Addr {
    /// Split the 128-bit address into the two `u64` words of the proto message.
    ///
    /// NOTE(review): octets 0..8 are the MOST significant half of the address
    /// yet are stored in the field named `low` (and octets 8..16 in `high`).
    /// The reverse conversion uses the same inverted convention, so in-process
    /// round-trips are correct — but renaming the proto fields now would break
    /// wire/JSON compatibility. Confirm before "fixing" the names.
    fn from(value: std::net::Ipv6Addr) -> Self {
        let b = value.octets();
        Self {
            low: u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]),
            high: u64::from_be_bytes([b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]]),
        }
    }
}
impl From<Ipv6Addr> for std::net::Ipv6Addr {
    /// Reassemble the address from the two stored words.
    ///
    /// NOTE(review): mirrors the inverted `low`/`high` naming used by the
    /// forward conversion — `low` supplies the leading (most-significant)
    /// octets. Keep the two impls in sync.
    fn from(value: Ipv6Addr) -> Self {
        let low = value.low.to_be_bytes();
        let high = value.high.to_be_bytes();
        std::net::Ipv6Addr::from([
            low[0], low[1], low[2], low[3], low[4], low[5], low[6], low[7], high[0], high[1],
            high[2], high[3], high[4], high[5], high[6], high[7],
        ])
    }
}
impl ToString for Ipv6Addr {
fn to_string(&self) -> String {
std::net::Ipv6Addr::from(self.clone()).to_string()
}
}
impl From<url::Url> for Url {
    /// Store the URL in its serialized string form.
    fn from(value: url::Url) -> Self {
        let url = value.to_string();
        Self { url }
    }
}
impl From<Url> for url::Url {
    /// Parse the stored string back into a `url::Url`.
    ///
    /// # Panics
    /// Panics if the stored string is not a valid URL. `From` cannot report
    /// errors, so an invalid `common::Url` message is treated as a bug on the
    /// producing side; the `expect` message makes that invariant explicit in
    /// the panic output (the original bare `unwrap` gave no context).
    fn from(value: Url) -> Self {
        url::Url::parse(&value.url).expect("common::Url does not contain a valid URL")
    }
}
impl FromStr for Url {
    type Err = url::ParseError;

    /// Validate `s` by round-tripping it through `url::Url`, storing the
    /// normalized serialization.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = s.parse::<url::Url>()?;
        Ok(Self {
            url: parsed.to_string(),
        })
    }
}
impl Display for Url {
    /// Emit the stored string verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.url)
    }
}
impl From<std::net::SocketAddr> for SocketAddr {
    /// Convert a std socket address into the proto form, preserving the
    /// IPv4/IPv6 distinction through the `ip` oneof.
    fn from(value: std::net::SocketAddr) -> Self {
        let ip = match value.ip() {
            std::net::IpAddr::V4(v4) => socket_addr::Ip::Ipv4(v4.into()),
            std::net::IpAddr::V6(v6) => socket_addr::Ip::Ipv6(v6.into()),
        };
        SocketAddr {
            ip: Some(ip),
            port: value.port() as u32,
        }
    }
}
impl From<SocketAddr> for std::net::SocketAddr {
    /// Convert the proto socket address back to the std form.
    ///
    /// # Panics
    /// Panics if the `ip` oneof is unset (`None`) — which is representable in
    /// proto3 — or if `port` exceeds `u16::MAX` is NOT checked (it is silently
    /// truncated by the `as u16` cast). Callers must only pass fully
    /// populated, in-range messages.
    fn from(value: SocketAddr) -> Self {
        match value.ip.unwrap() {
            socket_addr::Ip::Ipv4(ip) => std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
                std::net::Ipv4Addr::from(ip),
                value.port as u16,
            )),
            // flowinfo and scope_id are not carried over the wire; use 0.
            socket_addr::Ip::Ipv6(ip) => std::net::SocketAddr::V6(std::net::SocketAddrV6::new(
                std::net::Ipv6Addr::from(ip),
                value.port as u16,
                0,
                0,
            )),
        }
    }
}

View File

@ -0,0 +1,34 @@
syntax = "proto3";

package error;

// Catch-all for errors that have no dedicated variant.
message OtherError { string error_message = 1; }

// The requested method index does not exist on the named service.
message InvalidMethodIndex {
  string service_name = 1;
  uint32 method_index = 2;
}

// No service is registered under this name.
message InvalidService { string service_name = 1; }

// Payload failed to decode / encode with prost; no details are transmitted.
message ProstDecodeError {}
message ProstEncodeError {}

// The remote handler ran but returned an error.
message ExecuteError { string error_message = 1; }

// The RPC packet itself was malformed (bad piece counts, missing descriptor).
message MalformatRpcPacket { string error_message = 1; }

// The call did not complete within the requested timeout.
message Timeout { string error_message = 1; }

// Wire form of the in-process RPC error; exactly one variant should be set.
message Error {
  oneof error {
    OtherError other_error = 1;
    InvalidMethodIndex invalid_method_index = 2;
    InvalidService invalid_service = 3;
    ProstDecodeError prost_decode_error = 4;
    ProstEncodeError prost_encode_error = 5;
    ExecuteError execute_error = 6;
    MalformatRpcPacket malformat_rpc_packet = 7;
    Timeout timeout = 8;
  }
}

View File

@ -0,0 +1,84 @@
use prost::DecodeError;
use super::rpc_types;
include!(concat!(env!("OUT_DIR"), "/error.rs"));
/// Convert an in-process RPC error into its wire (protobuf) representation so
/// it can be carried inside an `RpcResponse`. The conversion is lossy: only
/// message text (or nothing, for prost codec errors) survives the trip.
impl From<&rpc_types::error::Error> for Error {
    fn from(e: &rpc_types::error::Error) -> Self {
        use super::error::error::Error as ProtoError;
        match e {
            // Server-side handler failure: keep only the message text.
            rpc_types::error::Error::ExecutionError(e) => Self {
                error: Some(ProtoError::ExecuteError(ExecuteError {
                    error_message: e.to_string(),
                })),
            },
            // prost decode/encode failures carry no useful payload on the
            // wire; an empty marker message is enough.
            rpc_types::error::Error::DecodeError(_) => Self {
                error: Some(ProtoError::ProstDecodeError(ProstDecodeError {})),
            },
            rpc_types::error::Error::EncodeError(_) => Self {
                error: Some(ProtoError::ProstEncodeError(ProstEncodeError {})),
            },
            rpc_types::error::Error::InvalidMethodIndex(m, s) => Self {
                error: Some(ProtoError::InvalidMethodIndex(InvalidMethodIndex {
                    method_index: *m as u32,
                    service_name: s.to_string(),
                })),
            },
            // Only the service name survives; the second tuple field is
            // dropped here and restored as "" by the reverse conversion.
            rpc_types::error::Error::InvalidServiceKey(s, _) => Self {
                error: Some(ProtoError::InvalidService(InvalidService {
                    service_name: s.to_string(),
                })),
            },
            rpc_types::error::Error::MalformatRpcPacket(e) => Self {
                error: Some(ProtoError::MalformatRpcPacket(MalformatRpcPacket {
                    error_message: e.to_string(),
                })),
            },
            rpc_types::error::Error::Timeout(e) => Self {
                error: Some(ProtoError::Timeout(Timeout {
                    error_message: e.to_string(),
                })),
            },
            // Fallback for variants added later; unreachable with the current
            // enum, hence the allow.
            #[allow(unreachable_patterns)]
            e => Self {
                error: Some(ProtoError::OtherError(OtherError {
                    error_message: e.to_string(),
                })),
            },
        }
    }
}
/// Convert a wire (protobuf) error back into the in-process error type.
///
/// The mapping is deliberately asymmetric in a few places — see the inline
/// notes — because some local variants wrap types that cannot be constructed
/// from a plain message string.
impl From<&Error> for rpc_types::error::Error {
    fn from(e: &Error) -> Self {
        use super::error::error::Error as ProtoError;
        match &e.error {
            Some(ProtoError::ExecuteError(e)) => {
                Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
            }
            Some(ProtoError::ProstDecodeError(_)) => {
                Self::DecodeError(DecodeError::new("decode error"))
            }
            // NOTE(review): a remote *encode* error is surfaced locally as a
            // DecodeError, presumably because `prost::EncodeError` exposes no
            // public constructor — confirm this asymmetry is intentional.
            Some(ProtoError::ProstEncodeError(_)) => {
                Self::DecodeError(DecodeError::new("encode error"))
            }
            Some(ProtoError::InvalidMethodIndex(e)) => {
                Self::InvalidMethodIndex(e.method_index as u8, e.service_name.clone())
            }
            // The second key component was not transmitted; restore as "".
            Some(ProtoError::InvalidService(e)) => {
                Self::InvalidServiceKey(e.service_name.clone(), "".to_string())
            }
            Some(ProtoError::MalformatRpcPacket(e)) => {
                Self::MalformatRpcPacket(e.error_message.clone())
            }
            // NOTE(review): remote timeouts are mapped to ExecutionError
            // rather than the local Timeout variant (which wraps a
            // non-constructible elapsed type) — confirm intentional.
            Some(ProtoError::Timeout(e)) => {
                Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
            }
            Some(ProtoError::OtherError(e)) => {
                Self::ExecutionError(anyhow::anyhow!(e.error_message.clone()))
            }
            // No oneof variant set at all: treat as an opaque failure.
            None => Self::ExecutionError(anyhow::anyhow!("unknown error {:?}", e)),
        }
    }
}

View File

@ -0,0 +1,9 @@
// Transport-level RPC implementation (client/server, packet split & merge).
pub mod rpc_impl;
// Framework traits shared by generated code (controller, handler, errors).
pub mod rpc_types;

// prost-generated messages plus rpc_build-generated service stubs.
pub mod cli;
pub mod common;
pub mod error;
pub mod peer_rpc;

pub mod tests;

View File

@ -0,0 +1,129 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";
import "common.proto";

package peer_rpc;

// Routing information one peer advertises about itself.
message RoutePeerInfo {
  // means next hop in route table.
  uint32 peer_id = 1;
  common.UUID inst_id = 2;
  uint32 cost = 3;
  optional common.Ipv4Addr ipv4_addr = 4;
  repeated string proxy_cidrs = 5;
  optional string hostname = 6;
  common.NatType udp_stun_info = 7;
  google.protobuf.Timestamp last_update = 8;
  uint32 version = 9;
}

// (peer_id, version) pair used to express which snapshot a bitmap refers to.
message PeerIdVersion {
  uint32 peer_id = 1;
  uint32 version = 2;
}

// Compact adjacency information: one bit per peer listed in `peer_ids`.
message RouteConnBitmap {
  repeated PeerIdVersion peer_ids = 1;
  bytes bitmap = 2;
}

message RoutePeerInfos { repeated RoutePeerInfo items = 1; }

// One round of OSPF-style route synchronization between two peers.
message SyncRouteInfoRequest {
  uint32 my_peer_id = 1;
  uint64 my_session_id = 2;
  bool is_initiator = 3;
  RoutePeerInfos peer_infos = 4;
  RouteConnBitmap conn_bitmap = 5;
}

enum SyncRouteInfoError {
  DuplicatePeerId = 0;
  Stopped = 1;
}

message SyncRouteInfoResponse {
  bool is_initiator = 1;
  uint64 session_id = 2;
  optional SyncRouteInfoError error = 3;
}

service OspfRouteRpc {
  // Exchange route tables and connectivity bitmaps with a neighbor.
  rpc SyncRouteInfo(SyncRouteInfoRequest) returns (SyncRouteInfoResponse);
}

message GetIpListRequest {}

// Addresses and listeners a peer exposes for direct connection attempts.
message GetIpListResponse {
  common.Ipv4Addr public_ipv4 = 1;
  repeated common.Ipv4Addr interface_ipv4s = 2;
  common.Ipv6Addr public_ipv6 = 3;
  repeated common.Ipv6Addr interface_ipv6s = 4;
  repeated common.Url listeners = 5;
}

service DirectConnectorRpc {
  rpc GetIpList(GetIpListRequest) returns (GetIpListResponse);
}

// UDP hole punching: cone-NAT attempt with a pre-mapped local address.
message TryPunchHoleRequest { common.SocketAddr local_mapped_addr = 1; }
message TryPunchHoleResponse { common.SocketAddr remote_mapped_addr = 1; }

// Symmetric-NAT attempt: the remote scans the given public IPs/port range,
// resuming from `last_port_index` across rounds.
message TryPunchSymmetricRequest {
  common.SocketAddr listener_addr = 1;
  uint32 port = 2;
  repeated common.Ipv4Addr public_ips = 3;
  uint32 min_port = 4;
  uint32 max_port = 5;
  uint32 transaction_id = 6;
  uint32 round = 7;
  uint32 last_port_index = 8;
}
message TryPunchSymmetricResponse { uint32 last_port_index = 1; }

service UdpHolePunchRpc {
  rpc TryPunchHole(TryPunchHoleRequest) returns (TryPunchHoleResponse);
  rpc TryPunchSymmetric(TryPunchSymmetricRequest)
      returns (TryPunchSymmetricResponse);
}

// Latency a peer observed to one of its direct neighbors.
message DirectConnectedPeerInfo { int32 latency_ms = 1; }

message PeerInfoForGlobalMap {
  map<uint32, DirectConnectedPeerInfo> direct_peers = 1;
}

// Peer-center protocol: peers report their direct neighbors and can fetch the
// aggregated global map (skipped when the digest is unchanged).
message ReportPeersRequest {
  uint32 my_peer_id = 1;
  PeerInfoForGlobalMap peer_infos = 2;
}
message ReportPeersResponse {}

message GlobalPeerMap { map<uint32, PeerInfoForGlobalMap> map = 1; }

message GetGlobalPeerMapRequest { uint64 digest = 1; }
message GetGlobalPeerMapResponse {
  map<uint32, PeerInfoForGlobalMap> global_peer_map = 1;
  optional uint64 digest = 2;
}

service PeerCenterRpc {
  rpc ReportPeers(ReportPeersRequest) returns (ReportPeersResponse);
  rpc GetGlobalPeerMap(GetGlobalPeerMapRequest)
      returns (GetGlobalPeerMapResponse);
}

// First message exchanged on a new peer connection.
message HandshakeRequest {
  uint32 magic = 1;
  uint32 my_peer_id = 2;
  uint32 version = 3;
  repeated string features = 4;
  string network_name = 5;
  // NOTE(review): field name keeps the historical typo "digrest"; renaming it
  // would change the generated Rust API and the JSON field name.
  bytes network_secret_digrest = 6;
}

View File

@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/peer_rpc.rs"));

View File

@ -0,0 +1,8 @@
[package]
name = "rpc_build"
version = "0.1.0"
edition = "2021"
[dependencies]
heck = "0.5"
prost-build = "0.13"

View File

@ -0,0 +1,383 @@
extern crate heck;
extern crate prost_build;
use std::fmt;
// Module path prefix the generated code uses to reach the runtime support
// types (controller/handler/error/__rt) in the consuming crate.
const NAMESPACE: &str = "crate::proto::rpc_types";
/// The service generator to be used with `prost-build` to generate RPC implementations for
/// `prost-simple-rpc`.
///
/// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)]
#[derive(Clone, Debug)]
pub struct ServiceGenerator {
    // Private marker field: forces construction through `new()` so options
    // can be added later without breaking downstream users.
    _private: (),
}
impl ServiceGenerator {
/// Create a new `ServiceGenerator` instance with the default options set.
pub fn new() -> ServiceGenerator {
ServiceGenerator { _private: () }
}
}
impl prost_build::ServiceGenerator for ServiceGenerator {
    /// Generate the Rust trait, client, server, factory and descriptor types
    /// for one proto `service`, appending the code to `buf`.
    fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
        use std::fmt::Write;

        // Names of the generated companion types for this service.
        let descriptor_name = format!("{}Descriptor", service.name);
        let server_name = format!("{}Server", service.name);
        let client_name = format!("{}Client", service.name);
        let method_descriptor_name = format!("{}MethodDescriptor", service.name);

        // Per-method code fragments, accumulated in the loop below and then
        // spliced into the final template.
        let mut trait_methods = String::new();
        let mut enum_methods = String::new();
        let mut list_enum_methods = String::new();
        let mut client_methods = String::new();
        let mut client_own_methods = String::new();
        let mut match_name_methods = String::new();
        let mut match_proto_name_methods = String::new();
        let mut match_input_type_methods = String::new();
        let mut match_input_proto_type_methods = String::new();
        let mut match_output_type_methods = String::new();
        let mut match_output_proto_type_methods = String::new();
        let mut match_handle_methods = String::new();
        let mut match_method_try_from = String::new();

        for (idx, method) in service.methods.iter().enumerate() {
            // Streaming RPCs are not supported by this framework.
            assert!(
                !method.client_streaming,
                "Client streaming not yet supported for method {}",
                method.proto_name
            );
            assert!(
                !method.server_streaming,
                "Server streaming not yet supported for method {}",
                method.proto_name
            );

            // Trait method signature (with the proto comments as docs).
            ServiceGenerator::write_comments(&mut trait_methods, 4, &method.comments).unwrap();
            writeln!(
                trait_methods,
                r#"    async fn {name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}>;"#,
                name = method.name,
                input_type = method.input_type,
                output_type = method.output_type,
                namespace = NAMESPACE,
            )
            .unwrap();

            // Method-descriptor enum variant; indices are 1-based so that 0
            // never maps to a valid method.
            ServiceGenerator::write_comments(&mut enum_methods, 4, &method.comments).unwrap();
            writeln!(
                enum_methods,
                "    {name} = {index},",
                name = method.proto_name,
                index = format!("{}", idx + 1)
            )
            .unwrap();

            writeln!(
                match_method_try_from,
                "            {index} => Ok({service_name}MethodDescriptor::{name}),",
                service_name = service.name,
                name = method.proto_name,
                index = format!("{}", idx + 1),
            )
            .unwrap();

            writeln!(
                list_enum_methods,
                "            {service_name}MethodDescriptor::{name},",
                service_name = service.name,
                name = method.proto_name
            )
            .unwrap();

            // Client-side trait impl: each method delegates to the inherent
            // `*_inner` helper generated just below.
            writeln!(
                client_methods,
                r#"    async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
        {client_name}::{name}_inner(self.0.clone(), ctrl, input).await
    }}"#,
                name = method.name,
                input_type = method.input_type,
                output_type = method.output_type,
                client_name = format!("{}Client", service.name),
                namespace = NAMESPACE,
            )
            .unwrap();

            writeln!(
                client_own_methods,
                r#"    async fn {name}_inner(handler: H, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
        {namespace}::__rt::call_method(handler, ctrl, {method_descriptor_name}::{proto_name}, input).await
    }}"#,
                name = method.name,
                method_descriptor_name = method_descriptor_name,
                proto_name = method.proto_name,
                input_type = method.input_type,
                output_type = method.output_type,
                namespace = NAMESPACE,
            ).unwrap();

            // Shared `Descriptor::Variant => ` prefix for all the match arms
            // generated for this method below.
            let case = format!(
                "            {service_name}MethodDescriptor::{proto_name} => ",
                service_name = service.name,
                proto_name = method.proto_name
            );
            writeln!(match_name_methods, "{}{:?},", case, method.name).unwrap();
            writeln!(match_proto_name_methods, "{}{:?},", case, method.proto_name).unwrap();
            writeln!(
                match_input_type_methods,
                "{}::std::any::TypeId::of::<{}>(),",
                case, method.input_type
            )
            .unwrap();
            writeln!(
                match_input_proto_type_methods,
                "{}{:?},",
                case, method.input_proto_type
            )
            .unwrap();
            writeln!(
                match_output_type_methods,
                "{}::std::any::TypeId::of::<{}>(),",
                case, method.output_type
            )
            .unwrap();
            writeln!(
                match_output_proto_type_methods,
                "{}{:?},",
                case, method.output_proto_type
            )
            .unwrap();
            // Server dispatch arm: decode the request, invoke the service
            // method, encode the reply.
            write!(
                match_handle_methods,
                r#"{} {{
                let decoded: {input_type} = {namespace}::__rt::decode(input)?;
                let ret = service.{name}(ctrl, decoded).await?;
                {namespace}::__rt::encode(ret)
            }}
"#,
                case,
                input_type = method.input_type,
                name = method.name,
                namespace = NAMESPACE,
            )
            .unwrap();
        }

        ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap();
        // Emit all generated items in a single formatted template. Note that
        // `{match_method_try_from}` is not in the argument list: it is picked
        // up via Rust 2021 implicit format-argument capture.
        write!(
            buf,
            r#"
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait {name} {{
    type Controller: {namespace}::controller::Controller;
{trait_methods}
}}

/// A service descriptor for a `{name}`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)]
pub struct {descriptor_name};

/// Methods available on a `{name}`.
///
/// This can be used as a key when routing requests for servers/clients of a `{name}`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum {method_descriptor_name} {{
{enum_methods}
}}

impl std::convert::TryFrom<u8> for {method_descriptor_name} {{
    type Error = {namespace}::error::Error;
    fn try_from(value: u8) -> {namespace}::error::Result<Self> {{
        match value {{
{match_method_try_from}
            _ => Err({namespace}::error::Error::InvalidMethodIndex(value, "{name}".to_string())),
        }}
    }}
}}

/// A client for a `{name}`.
///
/// This implements the `{name}` trait by dispatching all method calls to the supplied `Handler`.
#[derive(Clone, Debug)]
pub struct {client_name}<H>(H) where H: {namespace}::handler::Handler;

impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
    /// Creates a new client instance that delegates all method calls to the supplied handler.
    pub fn new(handler: H) -> {client_name}<H> {{
        {client_name}(handler)
    }}
}}

impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
{client_own_methods}
}}

#[async_trait::async_trait]
impl<H> {name} for {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
    type Controller = H::Controller;
{client_methods}
}}

pub struct {client_name}Factory<C: {namespace}::controller::Controller>(std::marker::PhantomData<C>);

impl<C: {namespace}::controller::Controller> Clone for {client_name}Factory<C> {{
    fn clone(&self) -> Self {{
        Self(std::marker::PhantomData)
    }}
}}

impl<C> {namespace}::__rt::RpcClientFactory for {client_name}Factory<C> where C: {namespace}::controller::Controller {{
    type Descriptor = {descriptor_name};
    type ClientImpl = Box<dyn {name}<Controller = C> + Send + 'static>;
    type Controller = C;
    fn new(handler: impl {namespace}::handler::Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>) -> Self::ClientImpl {{
        Box::new({client_name}::new(handler))
    }}
}}

/// A server for a `{name}`.
///
/// This implements the `Server` trait by handling requests and dispatch them to methods on the
/// supplied `{name}`.
#[derive(Clone, Debug)]
pub struct {server_name}<A>(A) where A: {name} + Clone + Send + 'static;

impl<A> {server_name}<A> where A: {name} + Clone + Send + 'static {{
    /// Creates a new server instance that dispatches all calls to the supplied service.
    pub fn new(service: A) -> {server_name}<A> {{
        {server_name}(service)
    }}
    async fn call_inner(
        service: A,
        method: {method_descriptor_name},
        ctrl: A::Controller,
        input: ::bytes::Bytes)
        -> {namespace}::error::Result<::bytes::Bytes> {{
        match method {{
{match_handle_methods}
        }}
    }}
}}

impl {namespace}::descriptor::ServiceDescriptor for {descriptor_name} {{
    type Method = {method_descriptor_name};
    fn name(&self) -> &'static str {{ {name:?} }}
    fn proto_name(&self) -> &'static str {{ {proto_name:?} }}
    fn package(&self) -> &'static str {{ {package:?} }}
    fn methods(&self) -> &'static [Self::Method] {{
        &[ {list_enum_methods} ]
    }}
}}

#[async_trait::async_trait]
impl<A> {namespace}::handler::Handler for {server_name}<A>
where
    A: {name} + Clone + Send + Sync + 'static {{
    type Descriptor = {descriptor_name};
    type Controller = A::Controller;
    async fn call(
        &self,
        ctrl: A::Controller,
        method: {method_descriptor_name},
        input: ::bytes::Bytes)
        -> {namespace}::error::Result<::bytes::Bytes> {{
        {server_name}::call_inner(self.0.clone(), method, ctrl, input).await
    }}
}}

impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{
    fn name(&self) -> &'static str {{
        match *self {{
{match_name_methods}
        }}
    }}
    fn proto_name(&self) -> &'static str {{
        match *self {{
{match_proto_name_methods}
        }}
    }}
    fn input_type(&self) -> ::std::any::TypeId {{
        match *self {{
{match_input_type_methods}
        }}
    }}
    fn input_proto_type(&self) -> &'static str {{
        match *self {{
{match_input_proto_type_methods}
        }}
    }}
    fn output_type(&self) -> ::std::any::TypeId {{
        match *self {{
{match_output_type_methods}
        }}
    }}
    fn output_proto_type(&self) -> &'static str {{
        match *self {{
{match_output_proto_type_methods}
        }}
    }}
    fn index(&self) -> u8 {{
        *self as u8
    }}
}}
"#,
            name = service.name,
            descriptor_name = descriptor_name,
            server_name = server_name,
            client_name = client_name,
            method_descriptor_name = method_descriptor_name,
            proto_name = service.proto_name,
            package = service.package,
            trait_methods = trait_methods,
            enum_methods = enum_methods,
            list_enum_methods = list_enum_methods,
            client_own_methods = client_own_methods,
            client_methods = client_methods,
            match_name_methods = match_name_methods,
            match_proto_name_methods = match_proto_name_methods,
            match_input_type_methods = match_input_type_methods,
            match_input_proto_type_methods = match_input_proto_type_methods,
            match_output_type_methods = match_output_type_methods,
            match_output_proto_type_methods = match_output_proto_type_methods,
            match_handle_methods = match_handle_methods,
            namespace = NAMESPACE,
        ).unwrap();
    }
}
impl ServiceGenerator {
    /// Write the leading comments of `comments` into `write` as `///` doc
    /// lines, each indented by `indent` spaces. Empty lines are skipped so
    /// the generated docs stay compact.
    fn write_comments<W>(
        mut write: W,
        indent: usize,
        comments: &prost_build::Comments,
    ) -> fmt::Result
    where
        W: fmt::Write,
    {
        // Hoist the loop-invariant indentation string; the original rebuilt
        // `" ".repeat(indent)` for every single comment line.
        let pad = " ".repeat(indent);
        for comment in &comments.leading {
            for line in comment.lines().filter(|s| !s.is_empty()) {
                writeln!(write, "{}///{}", pad, line)?;
            }
        }
        Ok(())
    }
}

View File

@ -0,0 +1,240 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use bytes::Bytes;
use dashmap::DashMap;
use prost::Message;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tokio::time::timeout;
use tokio_stream::StreamExt;
use crate::common::PeerId;
use crate::defer;
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse};
use crate::proto::rpc_impl::packet::build_rpc_packet;
use crate::proto::rpc_types::controller::Controller;
use crate::proto::rpc_types::descriptor::MethodDescriptor;
use crate::proto::rpc_types::{
__rt::RpcClientFactory, descriptor::ServiceDescriptor, handler::Handler,
};
use crate::proto::rpc_types::error::Result;
use crate::tunnel::mpsc::{MpscTunnel, MpscTunnelSender};
use crate::tunnel::packet_def::ZCPacket;
use crate::tunnel::ring::create_ring_tunnel_pair;
use crate::tunnel::{Tunnel, TunnelError, ZCPacketStream};
use super::packet::PacketMerger;
use super::{RpcTransactId, Transport};
// Process-wide transaction-id counter shared by all clients; seeded randomly
// so ids issued by different processes rarely collide.
static CUR_TID: once_cell::sync::Lazy<atomic_shim::AtomicI64> =
    once_cell::sync::Lazy::new(|| atomic_shim::AtomicI64::new(rand::random()));

// Channel endpoints used to hand a fully reassembled response packet back to
// the task that issued the request.
type RpcPacketSender = mpsc::UnboundedSender<RpcPacket>;
type RpcPacketReceiver = mpsc::UnboundedReceiver<RpcPacket>;
// Uniquely identifies one outstanding request: the (from, to, transaction)
// triple. Responses arrive with from/to swapped and are matched against this.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct InflightRequestKey {
    from_peer_id: PeerId,
    to_peer_id: PeerId,
    transaction_id: RpcTransactId,
}
// Per-request state kept while waiting for the response.
struct InflightRequest {
    // Wakes the caller once the full response has been reassembled.
    sender: RpcPacketSender,
    // Reassembles a multi-piece response.
    merger: PacketMerger,
    // When the request was issued. NOTE(review): not read anywhere in this
    // file's visible code — presumably used for metrics/diagnostics elsewhere;
    // confirm before removing.
    start_time: std::time::Instant,
}

type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
/// RPC client endpoint. Owns one end of an in-memory ring-tunnel pair: the
/// `transport` end is exposed to the owner (who bridges it to the network),
/// while `mpsc` is consumed internally by the response-dispatch task.
pub struct Client {
    mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
    transport: Mutex<Transport>,
    // Requests awaiting a response, keyed by (from, to, transaction id).
    inflight_requests: InflightRequestTable,
    // Background tasks spawned by `run()`.
    tasks: Arc<Mutex<JoinSet<()>>>,
}
impl Client {
    /// Create a client backed by an in-memory ring-tunnel pair: one end is
    /// consumed internally, the other is exposed as the transport the caller
    /// wires to the real network.
    pub fn new() -> Self {
        let (ring_a, ring_b) = create_ring_tunnel_pair();
        Self {
            mpsc: Mutex::new(MpscTunnel::new(ring_a)),
            transport: Mutex::new(MpscTunnel::new(ring_b)),
            inflight_requests: Arc::new(DashMap::new()),
            tasks: Arc::new(Mutex::new(JoinSet::new())),
        }
    }

    /// Sink on which the owner feeds RPC packets received from the network.
    pub fn get_transport_sink(&self) -> MpscTunnelSender {
        self.transport.lock().unwrap().get_sink()
    }

    /// Stream of outgoing packets the owner must deliver to the network.
    pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
        self.transport.lock().unwrap().get_stream()
    }

    /// Spawn the background task that reads response packets off the tunnel,
    /// reassembles multi-piece responses, and wakes the matching in-flight
    /// request.
    pub fn run(&self) {
        let mut tasks = self.tasks.lock().unwrap();

        let mut rx = self.mpsc.lock().unwrap().get_stream();
        let inflight_requests = self.inflight_requests.clone();
        tasks.spawn(async move {
            while let Some(packet) = rx.next().await {
                if let Err(err) = packet {
                    tracing::error!(?err, "Failed to receive packet");
                    continue;
                }

                let packet = match RpcPacket::decode(packet.unwrap().payload()) {
                    Err(err) => {
                        tracing::error!(?err, "Failed to decode packet");
                        continue;
                    }
                    Ok(packet) => packet,
                };

                // The client side only consumes responses; requests belong to
                // the server.
                if packet.is_request {
                    tracing::warn!(?packet, "Received non-response packet");
                    continue;
                }

                // A response carries from/to swapped relative to the request
                // we sent, so swap them back to build the lookup key.
                let key = InflightRequestKey {
                    from_peer_id: packet.to_peer,
                    to_peer_id: packet.from_peer,
                    transaction_id: packet.transaction_id,
                };

                // Late or unsolicited response (the request may have already
                // timed out and been removed).
                let Some(mut inflight_request) = inflight_requests.get_mut(&key) else {
                    tracing::warn!(?key, "No inflight request found for key");
                    continue;
                };

                let ret = inflight_request.merger.feed(packet);
                match ret {
                    Ok(Some(rpc_packet)) => {
                        // All pieces received: wake the waiting caller.
                        inflight_request.sender.send(rpc_packet).unwrap();
                    }
                    Ok(None) => {}
                    Err(err) => {
                        tracing::error!(?err, "Failed to feed packet to merger");
                    }
                }
            }
        });
    }

    /// Build a typed client (produced by factory `F`) whose calls are
    /// addressed from `from_peer_id` to `to_peer_id` within `domain_name`.
    pub fn scoped_client<F: RpcClientFactory>(
        &self,
        from_peer_id: PeerId,
        to_peer_id: PeerId,
        domain_name: String,
    ) -> F::ClientImpl {
        // Per-scoped-client handler: encodes a request, pushes it through the
        // shared tunnel, and waits for the reassembled response.
        #[derive(Clone)]
        struct HandlerImpl<F> {
            domain_name: String,
            from_peer_id: PeerId,
            to_peer_id: PeerId,
            zc_packet_sender: MpscTunnelSender,
            inflight_requests: InflightRequestTable,
            _phan: PhantomData<F>,
        }

        impl<F: RpcClientFactory> HandlerImpl<F> {
            // Send every request piece, then block on the merged response.
            async fn do_rpc(
                &self,
                packets: Vec<ZCPacket>,
                rx: &mut RpcPacketReceiver,
            ) -> Result<RpcPacket> {
                for packet in packets {
                    self.zc_packet_sender.send(packet).await?;
                }

                Ok(rx.recv().await.ok_or(TunnelError::Shutdown)?)
            }
        }

        #[async_trait::async_trait]
        impl<F: RpcClientFactory> Handler for HandlerImpl<F> {
            type Descriptor = F::Descriptor;
            type Controller = F::Controller;

            async fn call(
                &self,
                ctrl: Self::Controller,
                method: <Self::Descriptor as ServiceDescriptor>::Method,
                input: bytes::Bytes,
            ) -> Result<bytes::Bytes> {
                // Unique id pairing this request with its response.
                let transaction_id = CUR_TID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let (tx, mut rx) = mpsc::unbounded_channel();
                let key = InflightRequestKey {
                    from_peer_id: self.from_peer_id,
                    to_peer_id: self.to_peer_id,
                    transaction_id,
                };
                // Always drop the table entry, even on timeout/error paths.
                defer!(self.inflight_requests.remove(&key););
                self.inflight_requests.insert(
                    key.clone(),
                    InflightRequest {
                        sender: tx,
                        merger: PacketMerger::new(),
                        start_time: std::time::Instant::now(),
                    },
                );

                let desc = self.service_descriptor();
                let rpc_desc = RpcDescriptor {
                    domain_name: self.domain_name.clone(),
                    proto_name: desc.proto_name().to_string(),
                    service_name: desc.name().to_string(),
                    method_index: method.index() as u32,
                };

                let rpc_req = RpcRequest {
                    descriptor: Some(rpc_desc.clone()),
                    request: input.into(),
                    timeout_ms: ctrl.timeout_ms(),
                };

                // Split into MTU-sized pieces for the tunnel.
                let packets = build_rpc_packet(
                    self.from_peer_id,
                    self.to_peer_id,
                    rpc_desc,
                    transaction_id,
                    true,
                    &rpc_req.encode_to_vec(),
                    ctrl.trace_id(),
                );

                // Outer `?` is the timeout, inner `?` the transport error.
                let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
                let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
                assert_eq!(rpc_packet.transaction_id, transaction_id);

                let rpc_resp = RpcResponse::decode(Bytes::from(rpc_packet.body))?;
                if let Some(err) = &rpc_resp.error {
                    return Err(err.into());
                }

                Ok(bytes::Bytes::from(rpc_resp.response))
            }
        }

        F::new(HandlerImpl::<F> {
            // NOTE(review): `domain_name` is already a `String`; this
            // `to_string()` just clones it and could be dropped.
            domain_name: domain_name.to_string(),
            from_peer_id,
            to_peer_id,
            zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
            inflight_requests: self.inflight_requests.clone(),
            _phan: PhantomData,
        })
    }

    /// Number of requests currently awaiting a response.
    pub fn inflight_count(&self) -> usize {
        self.inflight_requests.len()
    }
}

View File

@ -0,0 +1,12 @@
use crate::tunnel::{mpsc::MpscTunnel, Tunnel};

// Default controller type used by RPC calls (carries timeout and trace id).
pub type RpcController = super::rpc_types::controller::BaseController;

pub mod client;
pub mod packet;
pub mod server;
pub mod service_registry;
pub mod standalone;

// Tunnel-backed transport shared by the client and server implementations.
pub type Transport = MpscTunnel<Box<dyn Tunnel>>;

// Request/response pairing id; i64 to match the proto `transaction_id` field.
pub type RpcTransactId = i64;

View File

@ -0,0 +1,161 @@
use prost::Message as _;
use crate::{
common::PeerId,
proto::{
common::{RpcDescriptor, RpcPacket},
rpc_types::error::Error,
},
tunnel::packet_def::{PacketType, ZCPacket},
};
use super::RpcTransactId;
// Max payload bytes per RPC piece; larger messages are fragmented and
// reassembled by `PacketMerger`. NOTE(review): presumably chosen to keep
// each piece under typical tunnel path MTU — confirm.
const RPC_PACKET_CONTENT_MTU: usize = 1300;
/// Reassembles one fragmented RPC message from its received pieces.
pub struct PacketMerger {
    // First piece seen for the current transaction; header template for
    // the merged packet.
    first_piece: Option<RpcPacket>,
    // Piece buffer indexed by piece_idx; not-yet-received slots hold
    // `RpcPacket::default()`.
    pieces: Vec<RpcPacket>,
    // Time of the last successful feed(); owners use it to expire stale mergers.
    last_updated: std::time::Instant,
}
impl PacketMerger {
    /// Create an empty merger with no partial transaction buffered.
    pub fn new() -> Self {
        Self {
            first_piece: None,
            pieces: Vec::new(),
            last_updated: std::time::Instant::now(),
        }
    }

    /// Return the reassembled packet if every piece has arrived, else `None`.
    fn try_merge_pieces(&self) -> Option<RpcPacket> {
        if self.first_piece.is_none() || self.pieces.is_empty() {
            return None;
        }

        for p in &self.pieces {
            // Some piece is missing: unfilled slots hold `RpcPacket::default()`,
            // whose total_pieces is 0 (feed() rejects real pieces with 0).
            if p.total_pieces == 0 {
                return None;
            }
        }

        // All pieces are received; concatenate bodies in piece_idx order.
        let mut body = Vec::new();
        for p in &self.pieces {
            body.extend_from_slice(&p.body);
        }

        // Reuse the first piece's header fields, rewritten as a single
        // self-contained packet.
        let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
        tmpl_packet.total_pieces = 1;
        tmpl_packet.piece_idx = 0;
        tmpl_packet.body = body;

        Some(tmpl_packet)
    }

    /// Feed one received piece into the merger.
    ///
    /// Returns `Ok(Some(packet))` once the full message is assembled,
    /// `Ok(None)` while pieces are still outstanding, and an error for
    /// malformed input (missing descriptor, invalid piece counts).
    pub fn feed(&mut self, rpc_packet: RpcPacket) -> Result<Option<RpcPacket>, Error> {
        let total_pieces = rpc_packet.total_pieces;
        let piece_idx = rpc_packet.piece_idx;

        if rpc_packet.descriptor.is_none() {
            return Err(Error::MalformatRpcPacket(
                "descriptor is missing".to_owned(),
            ));
        }

        // For compatibility with old version: a (0, 0) packet is a complete,
        // unfragmented message — pass it through untouched.
        if total_pieces == 0 && piece_idx == 0 {
            return Ok(Some(rpc_packet));
        }

        // Cap at 32 * 1024 pieces (~40MB of 1300-byte chunks) to bound memory
        // use; 0 is invalid here since the legacy case was handled above.
        if total_pieces > 32 * 1024 || total_pieces == 0 {
            return Err(Error::MalformatRpcPacket(format!(
                "total_pieces is invalid: {}",
                total_pieces
            )));
        }

        if piece_idx >= total_pieces {
            return Err(Error::MalformatRpcPacket(
                "piece_idx >= total_pieces".to_owned(),
            ));
        }

        // A piece from a different transaction or sender restarts the merge;
        // any partial state from the previous transaction is discarded.
        if self.first_piece.is_none()
            || self.first_piece.as_ref().unwrap().transaction_id != rpc_packet.transaction_id
            || self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer
        {
            self.first_piece = Some(rpc_packet.clone());
            self.pieces.clear();
        }

        self.pieces
            .resize(total_pieces as usize, Default::default());
        self.pieces[piece_idx as usize] = rpc_packet;
        self.last_updated = std::time::Instant::now();

        Ok(self.try_merge_pieces())
    }

    /// Time of the most recent successful `feed`; used to expire stale mergers.
    pub fn last_updated(&self) -> std::time::Instant {
        self.last_updated
    }
}
/// Split `content` into one or more `RpcPacket`s, each wrapped in a
/// `ZCPacket` ready to be sent through the peer manager.
///
/// Content is chunked into pieces of at most `RPC_PACKET_CONTENT_MTU`
/// bytes; `PacketMerger::feed` reassembles them on the receiving side.
/// Empty content yields a single packet with `total_pieces == 0` and
/// `piece_idx == 0`, which the merger treats as a self-contained
/// (legacy-format) packet.
///
/// Note: the parameter is `&[u8]` (any `&Vec<u8>` argument coerces to it),
/// which also accepts slices without forcing an owned Vec at call sites.
pub fn build_rpc_packet(
    from_peer: PeerId,
    to_peer: PeerId,
    rpc_desc: RpcDescriptor,
    transaction_id: RpcTransactId,
    is_req: bool,
    content: &[u8],
    trace_id: i32,
) -> Vec<ZCPacket> {
    let packet_type = if is_req {
        PacketType::RpcReq
    } else {
        PacketType::RpcResp
    };

    // 0 for empty content (legacy single-packet encoding, see above).
    let total_pieces = (content.len() + RPC_PACKET_CONTENT_MTU - 1) / RPC_PACKET_CONTENT_MTU;

    // Even empty content must produce exactly one (empty) packet.
    let pieces: Vec<&[u8]> = if content.is_empty() {
        vec![&[][..]]
    } else {
        content.chunks(RPC_PACKET_CONTENT_MTU).collect()
    };

    let mut ret = Vec::with_capacity(pieces.len());
    for (piece_idx, piece) in pieces.into_iter().enumerate() {
        let cur_packet = RpcPacket {
            from_peer,
            to_peer,
            descriptor: Some(rpc_desc.clone()),
            is_request: is_req,
            total_pieces: total_pieces as u32,
            piece_idx: piece_idx as u32,
            transaction_id,
            body: piece.to_vec(),
            trace_id,
        };

        let mut buf = Vec::new();
        // encode into a Vec cannot fail (infinite capacity buffer).
        cur_packet.encode(&mut buf).unwrap();

        let mut zc_packet = ZCPacket::new_with_payload(&buf);
        zc_packet.fill_peer_manager_hdr(from_peer, to_peer, packet_type as u8);
        ret.push(zc_packet);
    }

    ret
}

View File

@ -0,0 +1,207 @@
use std::{
pin::Pin,
sync::{Arc, Mutex},
};
use bytes::Bytes;
use dashmap::DashMap;
use prost::Message;
use tokio::{task::JoinSet, time::timeout};
use tokio_stream::StreamExt;
use crate::{
common::{join_joinset_background, PeerId},
proto::{
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse},
rpc_types::error::Result,
},
tunnel::{
mpsc::{MpscTunnel, MpscTunnelSender},
ring::create_ring_tunnel_pair,
Tunnel, ZCPacketStream,
},
};
use super::{
packet::{build_rpc_packet, PacketMerger},
service_registry::ServiceRegistry,
RpcController, Transport,
};
/// Identifies one in-progress request reassembly:
/// (sender peer, addressed service/method, transaction).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PacketMergerKey {
    from_peer_id: PeerId,
    rpc_desc: RpcDescriptor,
    transaction_id: i64,
}
/// Peer-rpc server: receives request packets, reassembles fragmented
/// requests, dispatches them to registered service handlers, and sends
/// response packets back.
pub struct Server {
    // Handler lookup table, possibly shared with other components.
    registry: Arc<ServiceRegistry>,
    // Inner endpoint of the ring-tunnel pair; taken (once) by run().
    mpsc: Mutex<Option<MpscTunnel<Box<dyn Tunnel>>>>,
    // Outer endpoint handed to callers via get_transport_sink/stream.
    transport: Mutex<Transport>,
    tasks: Arc<Mutex<JoinSet<()>>>,
    // Per-request reassembly state for fragmented requests.
    packet_mergers: Arc<DashMap<PacketMergerKey, PacketMerger>>,
}
impl Server {
    /// Create a server with its own empty service registry.
    pub fn new() -> Self {
        Server::new_with_registry(Arc::new(ServiceRegistry::new()))
    }

    /// Create a server dispatching to a shared, possibly pre-populated registry.
    ///
    /// A ring-tunnel pair links the internal receive loop (`mpsc`) with the
    /// transport endpoint exposed via `get_transport_sink`/`get_transport_stream`.
    pub fn new_with_registry(registry: Arc<ServiceRegistry>) -> Self {
        let (ring_a, ring_b) = create_ring_tunnel_pair();
        Self {
            registry,
            // Option so run() can move the tunnel into its spawned task.
            mpsc: Mutex::new(Some(MpscTunnel::new(ring_a))),
            transport: Mutex::new(MpscTunnel::new(ring_b)),
            tasks: Arc::new(Mutex::new(JoinSet::new())),
            packet_mergers: Arc::new(DashMap::new()),
        }
    }

    /// Registry used to look up service handlers for incoming requests.
    pub fn registry(&self) -> &ServiceRegistry {
        &self.registry
    }

    /// Sink on which callers push raw packets destined for this server.
    pub fn get_transport_sink(&self) -> MpscTunnelSender {
        self.transport.lock().unwrap().get_sink()
    }

    /// Stream of raw response packets produced by this server.
    pub fn get_transport_stream(&self) -> Pin<Box<dyn ZCPacketStream>> {
        self.transport.lock().unwrap().get_stream()
    }

    /// Start the server loops.
    ///
    /// Spawns one task that decodes incoming `RpcPacket`s, reassembles
    /// multi-piece requests via `PacketMerger`, and dispatches each complete
    /// request to `handle_rpc`; plus a janitor task that every 5s drops
    /// mergers idle for more than 10s. Panics if called twice (the inner
    /// tunnel is taken on the first call).
    pub fn run(&self) {
        let tasks = self.tasks.clone();
        join_joinset_background(tasks.clone(), "rpc server".to_string());

        let mpsc = self.mpsc.lock().unwrap().take().unwrap();
        let packet_merges = self.packet_mergers.clone();
        let reg = self.registry.clone();
        let t = tasks.clone();
        tasks.lock().unwrap().spawn(async move {
            let mut mpsc = mpsc;
            let mut rx = mpsc.get_stream();
            while let Some(packet) = rx.next().await {
                if let Err(err) = packet {
                    tracing::error!(?err, "Failed to receive packet");
                    continue;
                }

                let packet = match common::RpcPacket::decode(packet.unwrap().payload()) {
                    Err(err) => {
                        tracing::error!(?err, "Failed to decode packet");
                        continue;
                    }
                    Ok(packet) => packet,
                };

                // The server only consumes requests; a response here
                // indicates a routing problem upstream.
                if !packet.is_request {
                    tracing::warn!(?packet, "Received non-request packet");
                    continue;
                }

                let key = PacketMergerKey {
                    from_peer_id: packet.from_peer,
                    rpc_desc: packet.descriptor.clone().unwrap_or_default(),
                    transaction_id: packet.transaction_id,
                };

                let ret = packet_merges
                    .entry(key.clone())
                    .or_insert_with(PacketMerger::new)
                    .feed(packet);

                match ret {
                    Ok(Some(packet)) => {
                        // Complete request: drop the merger and handle it on
                        // a separate task so slow handlers don't block this loop.
                        packet_merges.remove(&key);
                        t.lock().unwrap().spawn(Self::handle_rpc(
                            mpsc.get_sink(),
                            packet,
                            reg.clone(),
                        ));
                    }
                    Ok(None) => {}
                    Err(err) => {
                        tracing::error!("Failed to feed packet to merger, {}", err.to_string());
                    }
                }
            }
        });

        let packet_mergers = self.packet_mergers.clone();
        tasks.lock().unwrap().spawn(async move {
            loop {
                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                // Expire partially-assembled requests whose sender went away.
                packet_mergers.retain(|_, v| v.last_updated().elapsed().as_secs() < 10);
            }
        });
    }

    /// Decode the request body and invoke the registered handler, bounded
    /// by the timeout the client embedded in the request.
    async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
        let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
        let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
        let ctrl = RpcController {};
        Ok(timeout(
            timeout_duration,
            reg.call_method(
                // Safe: PacketMerger::feed rejects packets without a descriptor.
                packet.descriptor.unwrap(),
                ctrl,
                Bytes::from(rpc_request.request),
            ),
        )
        .await??)
    }

    /// Run one request to completion and send the (possibly fragmented)
    /// response back through `sender`. Handler errors are serialized into
    /// the response message rather than propagated.
    async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc<ServiceRegistry>) {
        let from_peer = packet.from_peer;
        let to_peer = packet.to_peer;
        let transaction_id = packet.transaction_id;
        let trace_id = packet.trace_id;
        let desc = packet.descriptor.clone().unwrap();

        let mut resp_msg = RpcResponse::default();
        let now = std::time::Instant::now();

        let resp_bytes = Self::handle_rpc_request(packet, reg).await;
        match &resp_bytes {
            Ok(r) => {
                resp_msg.response = r.clone().into();
            }
            Err(err) => {
                resp_msg.error = Some(err.into());
            }
        };
        // Server-side handling latency, reported back to the caller.
        resp_msg.runtime_us = now.elapsed().as_micros() as u64;

        // Note the swapped peer ids: the response travels server -> client.
        let packets = build_rpc_packet(
            to_peer,
            from_peer,
            desc,
            transaction_id,
            false,
            &resp_msg.encode_to_vec(),
            trace_id,
        );

        for packet in packets {
            if let Err(err) = sender.send(packet).await {
                tracing::error!(?err, "Failed to send response packet");
            }
        }
    }

    /// Number of requests currently being reassembled (not yet dispatched).
    pub fn inflight_count(&self) -> usize {
        self.packet_mergers.len()
    }

    /// Close the outward-facing transport; already-spawned handler tasks
    /// are not aborted here.
    pub fn close(&self) {
        self.transport.lock().unwrap().close();
    }
}

View File

@ -0,0 +1,105 @@
use std::sync::Arc;
use dashmap::DashMap;
use crate::proto::common::RpcDescriptor;
use crate::proto::rpc_types;
use crate::proto::rpc_types::descriptor::ServiceDescriptor;
use crate::proto::rpc_types::handler::{Handler, HandlerExt};
use super::RpcController;
/// Registry lookup key: a service is identified by the triple
/// (domain, service name, proto name).
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct ServiceKey {
    pub domain_name: String,
    pub service_name: String,
    pub proto_name: String,
}
impl From<&RpcDescriptor> for ServiceKey {
fn from(desc: &RpcDescriptor) -> Self {
Self {
domain_name: desc.domain_name.to_string(),
service_name: desc.service_name.to_string(),
proto_name: desc.proto_name.to_string(),
}
}
}
/// A registered handler, type-erased so heterogeneous services can share
/// one table; cheap to clone (Arc).
#[derive(Clone)]
struct ServiceEntry {
    service: Arc<Box<dyn HandlerExt<Controller = RpcController>>>,
}
impl ServiceEntry {
    /// Wrap a typed handler into a type-erased entry.
    fn new<H: Handler<Controller = RpcController>>(h: H) -> Self {
        Self {
            service: Arc::new(Box::new(h)),
        }
    }

    /// Invoke the handler's method identified by `method_index` with raw bytes.
    async fn call_method(
        &self,
        ctrl: RpcController,
        method_index: u8,
        input: bytes::Bytes,
    ) -> rpc_types::error::Result<bytes::Bytes> {
        self.service.call_method(ctrl, method_index, input).await
    }
}
/// Concurrent table mapping `ServiceKey`s to registered RPC handlers.
pub struct ServiceRegistry {
    table: DashMap<ServiceKey, ServiceEntry>,
}
impl ServiceRegistry {
pub fn new() -> Self {
Self {
table: DashMap::new(),
}
}
pub fn register<H: Handler<Controller = RpcController>>(&self, h: H, domain_name: &str) {
let desc = h.service_descriptor();
let key = ServiceKey {
domain_name: domain_name.to_string(),
service_name: desc.name().to_string(),
proto_name: desc.proto_name().to_string(),
};
let entry = ServiceEntry::new(h);
self.table.insert(key, entry);
}
pub fn unregister<H: Handler<Controller = RpcController>>(
&self,
h: H,
domain_name: &str,
) -> Option<()> {
let desc = h.service_descriptor();
let key = ServiceKey {
domain_name: domain_name.to_string(),
service_name: desc.name().to_string(),
proto_name: desc.proto_name().to_string(),
};
self.table.remove(&key).map(|_| ())
}
pub async fn call_method(
&self,
rpc_desc: RpcDescriptor,
ctrl: RpcController,
input: bytes::Bytes,
) -> rpc_types::error::Result<bytes::Bytes> {
let service_key = ServiceKey::from(&rpc_desc);
let method_index = rpc_desc.method_index as u8;
let entry = self
.table
.get(&service_key)
.ok_or(rpc_types::error::Error::InvalidServiceKey(
service_key.service_name.clone(),
service_key.proto_name.clone(),
))?
.clone();
entry.call_method(ctrl, method_index, input).await
}
}

View File

@ -0,0 +1,245 @@
use std::{
sync::{atomic::AtomicU32, Arc, Mutex},
time::Duration,
};
use anyhow::Context as _;
use futures::{SinkExt as _, StreamExt};
use tokio::task::JoinSet;
use crate::{
common::join_joinset_background,
proto::rpc_types::{__rt::RpcClientFactory, error::Error},
tunnel::{Tunnel, TunnelConnector, TunnelListener},
};
use super::{client::Client, server::Server, service_registry::ServiceRegistry};
/// One accepted connection: a dedicated RPC server bridged to its tunnel.
struct StandAloneServerOneTunnel {
    tunnel: Box<dyn Tunnel>,
    rpc_server: Server,
}
impl StandAloneServerOneTunnel {
    /// Bind a dedicated RPC server (sharing `registry`) to one accepted tunnel.
    pub fn new(tunnel: Box<dyn Tunnel>, registry: Arc<ServiceRegistry>) -> Self {
        let rpc_server = Server::new_with_registry(registry);
        StandAloneServerOneTunnel { tunnel, rpc_server }
    }

    /// Pump packets between the tunnel and the RPC server until either
    /// direction stops, then shut the server down.
    pub async fn run(self) {
        use tokio_stream::StreamExt as _;

        let (tunnel_rx, tunnel_tx) = self.tunnel.split();
        let (rpc_rx, rpc_tx) = (
            self.rpc_server.get_transport_stream(),
            self.rpc_server.get_transport_sink(),
        );

        let mut tasks = JoinSet::new();

        tasks.spawn(async move {
            // Connections idle for 60s (no inbound packet) are dropped.
            let ret = tunnel_rx.timeout(Duration::from_secs(60));
            tokio::pin!(ret);
            while let Ok(Some(Ok(p))) = ret.try_next().await {
                if let Err(e) = rpc_tx.send(p).await {
                    tracing::error!("tunnel_rx send to rpc_tx error: {:?}", e);
                    break;
                }
            }
            tracing::info!("forward tunnel_rx to rpc_tx done");
        });

        tasks.spawn(async move {
            let ret = rpc_rx.forward(tunnel_tx).await;
            tracing::info!("rpc_rx forward tunnel_tx done: {:?}", ret);
        });

        self.rpc_server.run();

        // The first forwarding task to finish triggers server shutdown.
        while let Some(ret) = tasks.join_next().await {
            self.rpc_server.close();
            tracing::info!("task done: {:?}", ret);
        }

        tracing::info!("all tasks done");
    }
}
/// RPC server that owns a tunnel listener and spawns one per-connection
/// server (`StandAloneServerOneTunnel`) for every accepted tunnel.
pub struct StandAloneServer<L> {
    registry: Arc<ServiceRegistry>,
    // Consumed by serve(); present only before the first call.
    listener: Option<L>,
    // Count of currently-running per-connection servers.
    inflight_server: Arc<AtomicU32>,
    tasks: Arc<Mutex<JoinSet<()>>>,
}
impl<L: TunnelListener + 'static> StandAloneServer<L> {
    /// Wrap `listener` in a standalone RPC server with an empty registry.
    pub fn new(listener: L) -> Self {
        StandAloneServer {
            registry: Arc::new(ServiceRegistry::new()),
            listener: Some(listener),
            inflight_server: Arc::new(AtomicU32::new(0)),
            tasks: Arc::new(Mutex::new(JoinSet::new())),
        }
    }

    /// Registry to populate with service handlers (before or after `serve`).
    pub fn registry(&self) -> &ServiceRegistry {
        &self.registry
    }

    /// Start listening and serving; returns once the accept loop is spawned.
    ///
    /// Each accepted tunnel gets its own `StandAloneServerOneTunnel` task.
    /// May only be called once (the listener is taken on the first call);
    /// the spawned task panics if the accept loop ever exits.
    pub async fn serve(&mut self) -> Result<(), Error> {
        let tasks = self.tasks.clone();
        let mut listener = self.listener.take().unwrap();
        let registry = self.registry.clone();

        join_joinset_background(tasks.clone(), "standalone server tasks".to_string());

        listener
            .listen()
            .await
            .with_context(|| "failed to listen")?;

        let inflight_server = self.inflight_server.clone();

        self.tasks.lock().unwrap().spawn(async move {
            while let Ok(tunnel) = listener.accept().await {
                let server = StandAloneServerOneTunnel::new(tunnel, registry.clone());
                let inflight_server = inflight_server.clone();
                // Track live per-connection servers for tests/diagnostics.
                inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                tasks.lock().unwrap().spawn(async move {
                    server.run().await;
                    inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
                });
            }
            panic!("standalone server listener exit");
        });

        Ok(())
    }

    /// Number of currently connected (per-tunnel) server instances.
    pub fn inflight_server(&self) -> u32 {
        self.inflight_server
            .load(std::sync::atomic::Ordering::Relaxed)
    }
}
/// One live client connection: an RPC client bridged to its tunnel, the
/// forwarding tasks, and the first transport error observed on it.
struct StandAloneClientOneTunnel {
    rpc_client: Client,
    // Keeps the two forwarding tasks alive for this connection's lifetime.
    tasks: Arc<Mutex<JoinSet<()>>>,
    error: Arc<Mutex<Option<Error>>>,
}
impl StandAloneClientOneTunnel {
    /// Attach an RPC client to `tunnel` and start bidirectional forwarding.
    ///
    /// The first transport failure in either direction is recorded in the
    /// shared `error` slot; `StandAloneClient::scoped_client` inspects it
    /// to decide when to reconnect.
    pub fn new(tunnel: Box<dyn Tunnel>) -> Self {
        let rpc_client = Client::new();
        let (mut rpc_rx, rpc_tx) = (
            rpc_client.get_transport_stream(),
            rpc_client.get_transport_sink(),
        );
        let tasks = Arc::new(Mutex::new(JoinSet::new()));
        let (mut tunnel_rx, mut tunnel_tx) = tunnel.split();

        let error_store = Arc::new(Mutex::new(None));

        // client -> tunnel direction.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = rpc_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = tunnel_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }
            // Stream ended: record it so the owner reconnects.
            *error.lock().unwrap() = Some(anyhow::anyhow!("rpc_rx next exit").into());
        });

        // tunnel -> client direction.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = tunnel_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = rpc_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }
            *error.lock().unwrap() = Some(anyhow::anyhow!("tunnel_rx next exit").into());
        });

        rpc_client.run();

        StandAloneClientOneTunnel {
            rpc_client,
            tasks,
            error: error_store,
        }
    }

    /// Return and clear the recorded transport error, if any.
    pub fn take_error(&self) -> Option<Error> {
        self.error.lock().unwrap().take()
    }
}
/// RPC client that dials `connector` lazily and transparently reconnects
/// when the previous connection reported a transport error.
pub struct StandAloneClient<C: TunnelConnector> {
    connector: C,
    // Current connection, if one has been established.
    client: Option<StandAloneClientOneTunnel>,
}
impl<C: TunnelConnector> StandAloneClient<C> {
    /// Create a lazily-connecting client; no tunnel is opened until the
    /// first `scoped_client` call.
    pub fn new(connector: C) -> Self {
        StandAloneClient {
            connector,
            client: None,
        }
    }

    /// Open a fresh tunnel to the server, annotating failures with the URL.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, Error> {
        Ok(self.connector.connect().await.with_context(|| {
            format!(
                "failed to connect to server: {:?}",
                self.connector.remote_url()
            )
        })?)
    }

    /// Get a typed client stub for `domain_name`, (re)connecting first if
    /// there is no live tunnel or the previous one reported an error.
    pub async fn scoped_client<F: RpcClientFactory>(
        &mut self,
        domain_name: String,
    ) -> Result<F::ClientImpl, Error> {
        let mut c = self.client.take();
        let error = c.as_ref().and_then(|c| c.take_error());
        if c.is_none() || error.is_some() {
            tracing::info!("reconnect due to error: {:?}", error);
            let tunnel = self.connect().await?;
            c = Some(StandAloneClientOneTunnel::new(tunnel));
        }
        self.client = c;
        // Fixed peer ids (1, 1): standalone mode has exactly one
        // client/server pair per tunnel, so no real routing is needed.
        Ok(self
            .client
            .as_ref()
            .unwrap()
            .rpc_client
            .scoped_client::<F>(1, 1, domain_name))
    }
}

View File

@ -0,0 +1,57 @@
//! Utility functions used by generated code; this is *not* part of the crate's public API!
use bytes;
use prost;
use super::controller;
use super::descriptor;
use super::descriptor::ServiceDescriptor;
use super::error;
use super::handler;
use super::handler::Handler;
/// Decode a protobuf message of type `M` from a byte buffer, converting
/// prost decode failures into the crate's RPC error type.
pub fn decode<M>(buf: bytes::Bytes) -> error::Result<M>
where
    M: prost::Message + Default,
{
    M::decode(buf).map_err(Into::into)
}
/// Encode a protobuf message into a frozen byte buffer, pre-sizing the
/// buffer from the message's encoded length to avoid reallocation.
pub fn encode<M>(message: M) -> error::Result<bytes::Bytes>
where
    M: prost::Message,
{
    let mut buf = ::bytes::BytesMut::with_capacity(message.encoded_len());
    message.encode(&mut buf)?;
    Ok(buf.freeze())
}
/// Dispatch a typed RPC call through a raw `Handler`.
///
/// Encodes `input` with prost, performs the byte-level `call` on the
/// handler, and decodes the returned bytes into the expected output type.
/// Encode, transport, and decode failures all surface as `error::Error`.
pub async fn call_method<H, I, O>(
    handler: H,
    ctrl: H::Controller,
    method: <H::Descriptor as descriptor::ServiceDescriptor>::Method,
    input: I,
) -> super::error::Result<O>
where
    H: handler::Handler,
    I: prost::Message,
    O: prost::Message + Default,
{
    // (removed an unused local `type Error = super::error::Error;` alias)
    let input_bytes = encode(input)?;
    let ret_msg = handler.call(ctrl, method, input_bytes).await?;
    decode(ret_msg)
}
/// Factory implemented by generated code to build typed RPC client stubs
/// on top of a raw byte-level `Handler`.
pub trait RpcClientFactory: Clone + Send + Sync + 'static {
    /// Service descriptor of the service the stub talks to.
    type Descriptor: ServiceDescriptor + Default;
    /// Concrete generated client type returned by `new`.
    type ClientImpl;
    /// Controller type carrying per-call metadata (timeout, trace id).
    type Controller: controller::Controller;
    /// Wrap a raw handler (e.g. a network transport) into a typed client.
    fn new(
        handler: impl Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>,
    ) -> Self::ClientImpl;
}

View File

@ -0,0 +1,18 @@
/// Per-call RPC metadata accessors (timeout, trace id).
///
/// The defaults give a fixed 5s timeout and a zero trace id; the default
/// setters are no-ops, so stateless implementations need no code at all.
pub trait Controller: Send + Sync + 'static {
    /// Timeout budget for a single RPC call, in milliseconds.
    fn timeout_ms(&self) -> i32 {
        5000
    }
    /// Set the timeout; the default implementation ignores the value.
    fn set_timeout_ms(&mut self, _timeout_ms: i32) {}
    /// Set the trace id; the default implementation ignores the value.
    fn set_trace_id(&mut self, _trace_id: i32) {}
    /// Trace id propagated with the call; 0 means "no tracing".
    fn trace_id(&self) -> i32 {
        0
    }
}
/// Minimal `Controller` that relies entirely on the trait's defaults.
#[derive(Debug)]
pub struct BaseController {}
impl Controller for BaseController {}

View File

@ -0,0 +1,50 @@
//! Traits for defining generic service descriptor definitions.
//!
//! These traits are built on the assumption that some form of code generation is being used (e.g.
//! using only `&'static str`s) but it's of course possible to implement these traits manually.
use std::any;
use std::fmt;
/// A descriptor for an available RPC service.
pub trait ServiceDescriptor: Clone + fmt::Debug + Send + Sync {
    /// The associated type of method descriptors.
    ///
    /// `TryFrom<u8>` lets the framework map a wire-level method index back
    /// to a typed method (see `Handler::get_method_from_index`).
    type Method: MethodDescriptor + fmt::Debug + TryFrom<u8>;

    /// The name of the service, used in Rust code and perhaps for human readability.
    fn name(&self) -> &'static str;

    /// The raw protobuf name of the service.
    fn proto_name(&self) -> &'static str;

    /// The package name of the service; empty by default.
    fn package(&self) -> &'static str {
        ""
    }

    /// All of the available methods on the service.
    fn methods(&self) -> &'static [Self::Method];
}
/// A descriptor for a method available on an RPC service.
pub trait MethodDescriptor: Clone + Copy + fmt::Debug + Send + Sync {
    /// The name of the method, used in Rust code and perhaps for human readability.
    fn name(&self) -> &'static str;
    /// The raw protobuf name of the method.
    fn proto_name(&self) -> &'static str;
    /// The Rust `TypeId` for the input that this method accepts.
    fn input_type(&self) -> any::TypeId;
    /// The raw protobuf name for the input type that this method accepts.
    fn input_proto_type(&self) -> &'static str;
    /// The Rust `TypeId` for the output that this method produces.
    fn output_type(&self) -> any::TypeId;
    /// The raw protobuf name for the output type that this method produces.
    fn output_proto_type(&self) -> &'static str;
    /// The index of the method in the service descriptor; also the wire-level
    /// method id carried in `RpcDescriptor.method_index`.
    fn index(&self) -> u8;
}

View File

@ -0,0 +1,34 @@
//! Error type definitions for errors that can occur during RPC interactions.
use std::result;
use prost;
use thiserror;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("rust tun error {0}")]
ExecutionError(#[from] anyhow::Error),
#[error("Decode error: {0}")]
DecodeError(#[from] prost::DecodeError),
#[error("Encode error: {0}")]
EncodeError(#[from] prost::EncodeError),
#[error("Invalid method index: {0}, service: {1}")]
InvalidMethodIndex(u8, String),
#[error("Invalid service name: {0}, proto name: {1}")]
InvalidServiceKey(String, String),
#[error("Invalid packet: {0}")]
MalformatRpcPacket(String),
#[error("Timeout: {0}")]
Timeout(#[from] tokio::time::error::Elapsed),
#[error("Tunnel error: {0}")]
TunnelError(#[from] crate::tunnel::TunnelError),
}
pub type Result<T> = result::Result<T, Error>;

View File

@ -0,0 +1,67 @@
//! Traits for defining generic RPC handlers.
use super::{
controller::Controller,
descriptor::{self, ServiceDescriptor},
};
use bytes;
/// An implementation of a specific RPC handler.
///
/// This can be an actual implementation of a service, or something that will send a request over
/// a network to fulfill a request.
#[async_trait::async_trait]
pub trait Handler: Clone + Send + Sync + 'static {
    /// The service descriptor for the service whose requests this handler can handle.
    type Descriptor: descriptor::ServiceDescriptor + Default;

    /// Controller type carrying per-call metadata (timeout, trace id).
    type Controller: super::controller::Controller;

    /// Perform a raw call to the specified service and method.
    async fn call(
        &self,
        ctrl: Self::Controller,
        method: <Self::Descriptor as descriptor::ServiceDescriptor>::Method,
        input: bytes::Bytes,
    ) -> super::error::Result<bytes::Bytes>;

    /// Descriptor instance for this handler's service (descriptors carry
    /// only static data, so `Default` suffices).
    fn service_descriptor(&self) -> Self::Descriptor {
        Self::Descriptor::default()
    }

    /// Map a wire-level method index to a typed method, or fail with
    /// `InvalidMethodIndex`.
    fn get_method_from_index(
        &self,
        index: u8,
    ) -> super::error::Result<<Self::Descriptor as descriptor::ServiceDescriptor>::Method> {
        let desc = self.service_descriptor();
        <Self::Descriptor as descriptor::ServiceDescriptor>::Method::try_from(index)
            .map_err(|_| super::error::Error::InvalidMethodIndex(index, desc.name().to_string()))
    }
}
/// Object-safe companion of `Handler`: dispatches by raw method index, so
/// heterogeneous handlers can be stored as `Box<dyn HandlerExt<...>>`
/// (e.g. in `ServiceRegistry`).
#[async_trait::async_trait]
pub trait HandlerExt: Send + Sync + 'static {
    type Controller;
    /// Look up the method by index and invoke it with the raw input bytes.
    async fn call_method(
        &self,
        ctrl: Self::Controller,
        method_index: u8,
        input: bytes::Bytes,
    ) -> super::error::Result<bytes::Bytes>;
}
/// Blanket implementation: every typed `Handler` is also a `HandlerExt`.
#[async_trait::async_trait]
impl<C: Controller, T: Handler<Controller = C>> HandlerExt for T {
    type Controller = C;

    async fn call_method(
        &self,
        ctrl: Self::Controller,
        method_index: u8,
        input: bytes::Bytes,
    ) -> super::error::Result<bytes::Bytes> {
        // Resolve the typed method first, then delegate to the raw call.
        let method = self.get_method_from_index(method_index)?;
        self.call(ctrl, method, input).await
    }
}

View File

@ -0,0 +1,5 @@
pub mod __rt; // runtime helpers used by generated code (not public API)
pub mod controller; // per-call metadata (timeout, trace id)
pub mod descriptor; // service/method descriptor traits
pub mod error; // unified error/result types
pub mod handler; // raw byte-level handler traits

View File

@ -0,0 +1,24 @@
// Test-only service definition used by the rpc framework's unit tests.
syntax = "proto3";

package tests;

/// The Greeting service. This service is used to generate greetings for various
/// use-cases.
service Greeting {
  // Generates a "hello" greeting based on the supplied info.
  rpc SayHello(SayHelloRequest) returns (SayHelloResponse);

  // Generates a "goodbye" greeting based on the supplied info.
  rpc SayGoodbye(SayGoodbyeRequest) returns (SayGoodbyeResponse);
}

// The request for an `Greeting.SayHello` call.
message SayHelloRequest { string name = 1; }
// The response for an `Greeting.SayHello` call.
message SayHelloResponse { string greeting = 1; }

// The request for an `Greeting.SayGoodbye` call.
message SayGoodbyeRequest { string name = 1; }
// The response for an `Greeting.SayGoodbye` call.
message SayGoodbyeResponse { string greeting = 1; }

225
easytier/src/proto/tests.rs Normal file
View File

@ -0,0 +1,225 @@
include!(concat!(env!("OUT_DIR"), "/tests.rs"));
use std::sync::{Arc, Mutex};
use futures::StreamExt as _;
use tokio::task::JoinSet;
use super::rpc_impl::RpcController;
/// Test implementation of the generated `Greeting` service.
#[derive(Clone)]
pub struct GreetingService {
    // Artificial per-call delay, used to exercise timeout handling.
    pub delay_ms: u64,
    // Prefix inserted into say_hello responses (e.g. "Hello").
    pub prefix: String,
}
#[async_trait::async_trait]
impl Greeting for GreetingService {
    type Controller = RpcController;

    /// Generates a "hello" greeting ("{prefix} {name}!") after `delay_ms`.
    async fn say_hello(
        &self,
        _ctrl: Self::Controller,
        input: SayHelloRequest,
    ) -> crate::proto::rpc_types::error::Result<SayHelloResponse> {
        let resp = SayHelloResponse {
            greeting: format!("{} {}!", self.prefix, input.name),
        };
        // Simulate server-side processing time (drives the timeout test).
        tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
        Ok(resp)
    }

    /// Generates a "goodbye" greeting based on the supplied info.
    async fn say_goodbye(
        &self,
        _ctrl: Self::Controller,
        input: SayGoodbyeRequest,
    ) -> crate::proto::rpc_types::error::Result<SayGoodbyeResponse> {
        let resp = SayGoodbyeResponse {
            greeting: format!("Goodbye, {}!", input.name),
        };
        tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
        Ok(resp)
    }
}
use crate::proto::rpc_impl::client::Client;
use crate::proto::rpc_impl::server::Server;
/// In-process rpc client/server pair with their transports cross-wired.
struct TestContext {
    client: Client,
    server: Server,
    // Keeps the two forwarding tasks alive for the test's duration.
    tasks: Arc<Mutex<JoinSet<()>>>,
}
impl TestContext {
    /// Build a running client/server pair and spawn two tasks that shuttle
    /// packets between their transports (full duplex, entirely in memory).
    fn new() -> Self {
        let rpc_server = Server::new();
        rpc_server.run();

        let client = Client::new();
        client.run();

        let tasks = Arc::new(Mutex::new(JoinSet::new()));

        // server -> client direction.
        let (mut rx, tx) = (
            rpc_server.get_transport_stream(),
            client.get_transport_sink(),
        );
        tasks.lock().unwrap().spawn(async move {
            while let Some(Ok(packet)) = rx.next().await {
                if let Err(err) = tx.send(packet).await {
                    println!("{:?}", err);
                    break;
                }
            }
        });

        // client -> server direction.
        let (mut rx, tx) = (
            client.get_transport_stream(),
            rpc_server.get_transport_sink(),
        );
        tasks.lock().unwrap().spawn(async move {
            while let Some(Ok(packet)) = rx.next().await {
                if let Err(err) = tx.send(packet).await {
                    println!("{:?}", err);
                    break;
                }
            }
        });

        Self {
            client,
            server: rpc_server,
            tasks,
        }
    }
}
/// Generate a random alphanumeric ASCII string of length `len`.
fn random_string(len: usize) -> String {
    use rand::distributions::Alphanumeric;
    use rand::Rng;
    let mut rng = rand::thread_rng();
    // Alphanumeric yields ASCII bytes, so the utf8 conversion cannot fail.
    let bytes: Vec<u8> = (0..len).map(|_| rng.sample(Alphanumeric)).collect();
    String::from_utf8(bytes).unwrap()
}
/// End-to-end smoke test: small and multi-piece (20MB) requests/responses
/// complete correctly and no inflight state leaks afterwards.
#[tokio::test]
async fn rpc_basic_test() {
    let ctx = TestContext::new();

    let server = GreetingServer::new(GreetingService {
        delay_ms: 0,
        prefix: "Hello".to_string(),
    });
    ctx.server.registry().register(server, "");

    let out = ctx
        .client
        .scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());

    // small size req and resp
    let ctrl = RpcController {};
    let input = SayHelloRequest {
        name: "world".to_string(),
    };
    let ret = out.say_hello(ctrl, input).await;
    assert_eq!(ret.unwrap().greeting, "Hello world!");

    let ctrl = RpcController {};
    let input = SayGoodbyeRequest {
        name: "world".to_string(),
    };
    let ret = out.say_goodbye(ctrl, input).await;
    assert_eq!(ret.unwrap().greeting, "Goodbye, world!");

    // large size req and resp (forces fragmentation through PacketMerger)
    let ctrl = RpcController {};
    let name = random_string(20 * 1024 * 1024);
    let input = SayGoodbyeRequest { name: name.clone() };
    let ret = out.say_goodbye(ctrl, input).await;
    assert_eq!(ret.unwrap().greeting, format!("Goodbye, {}!", name));

    // both sides must have cleaned up their per-request state
    assert_eq!(0, ctx.client.inflight_count());
    assert_eq!(0, ctx.server.inflight_count());
}
/// A handler slower (10s) than the controller's default timeout must surface
/// `Error::Timeout` on the client and leave no inflight state behind.
#[tokio::test]
async fn rpc_timeout_test() {
    let ctx = TestContext::new();

    let server = GreetingServer::new(GreetingService {
        delay_ms: 10000,
        prefix: "Hello".to_string(),
    });
    ctx.server.registry().register(server, "test");

    let out = ctx
        .client
        .scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());

    let ctrl = RpcController {};
    let input = SayHelloRequest {
        name: "world".to_string(),
    };
    let ret = out.say_hello(ctrl, input).await;
    assert!(ret.is_err());
    assert!(matches!(
        ret.unwrap_err(),
        crate::proto::rpc_types::error::Error::Timeout(_)
    ));

    assert_eq!(0, ctx.client.inflight_count());
    assert_eq!(0, ctx.server.inflight_count());
}
/// Exercises StandAloneServer/Client over a real TCP listener, including
/// reuse of the same connection for a second scoped client, and verifies
/// the per-connection server count drops to 0 after the client is dropped.
/// NOTE(review): binds fixed port 33455 — may collide with parallel tests.
#[tokio::test]
async fn standalone_rpc_test() {
    use crate::proto::rpc_impl::standalone::{StandAloneClient, StandAloneServer};
    use crate::tunnel::tcp::{TcpTunnelConnector, TcpTunnelListener};

    let mut server = StandAloneServer::new(TcpTunnelListener::new(
        "tcp://0.0.0.0:33455".parse().unwrap(),
    ));
    let service = GreetingServer::new(GreetingService {
        delay_ms: 0,
        prefix: "Hello".to_string(),
    });
    server.registry().register(service, "test");
    server.serve().await.unwrap();

    let mut client = StandAloneClient::new(TcpTunnelConnector::new(
        "tcp://127.0.0.1:33455".parse().unwrap(),
    ));

    let out = client
        .scoped_client::<GreetingClientFactory<RpcController>>("test".to_string())
        .await
        .unwrap();

    let ctrl = RpcController {};
    let input = SayHelloRequest {
        name: "world".to_string(),
    };
    let ret = out.say_hello(ctrl, input).await;
    assert_eq!(ret.unwrap().greeting, "Hello world!");

    // Second scoped client reuses the existing (error-free) connection.
    let out = client
        .scoped_client::<GreetingClientFactory<RpcController>>("test".to_string())
        .await
        .unwrap();

    let ctrl = RpcController {};
    let input = SayGoodbyeRequest {
        name: "world".to_string(),
    };
    let ret = out.say_goodbye(ctrl, input).await;
    assert_eq!(ret.unwrap().greeting, "Goodbye, world!");

    drop(client);
    // Give the server time to observe the disconnect.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;

    assert_eq!(0, server.inflight_server());
}

View File

@ -1 +0,0 @@
tonic::include_proto!("cli"); // The string specified here must match the proto package name

View File

@ -1,4 +0,0 @@
pub mod cli;
pub use cli::*;
pub mod peer;

View File

@ -1,22 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Default)]
pub struct GetIpListResponse {
pub public_ipv4: String,
pub interface_ipv4s: Vec<String>,
pub public_ipv6: String,
pub interface_ipv6s: Vec<String>,
pub listeners: Vec<url::Url>,
}
impl GetIpListResponse {
pub fn new() -> Self {
GetIpListResponse {
public_ipv4: "".to_string(),
interface_ipv4s: vec![],
public_ipv6: "".to_string(),
interface_ipv6s: vec![],
listeners: vec![],
}
}
}

View File

@ -127,7 +127,7 @@ pub fn enable_log() {
.init();
}
fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::rpc::Route>) {
fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::proto::cli::Route>) {
let mut found = false;
for r in routes.iter() {
if r.ipv4_addr == ipv4.to_string() {

View File

@ -518,8 +518,8 @@ pub async fn foreign_network_forward_nic_data() {
wait_for_condition(
|| async {
inst1.get_peer_manager().list_routes().await.len() == 1
&& inst2.get_peer_manager().list_routes().await.len() == 1
inst1.get_peer_manager().list_routes().await.len() == 2
&& inst2.get_peer_manager().list_routes().await.len() == 2
},
Duration::from_secs(5),
)

View File

@ -16,10 +16,9 @@ use tokio_stream::StreamExt;
use tokio_util::io::poll_write_buf;
use zerocopy::FromBytes as _;
use crate::{
rpc::TunnelInfo,
tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE},
};
use super::TunnelInfo;
use crate::tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE};
use super::{
buf::BufList,
@ -505,8 +504,8 @@ pub mod tests {
let ret = listener.accept().await.unwrap();
println!("accept: {:?}", ret.info());
assert_eq!(
ret.info().unwrap().local_addr,
listener.local_url().to_string()
url::Url::from(ret.info().unwrap().local_addr.unwrap()),
listener.local_url()
);
_tunnel_echo_server(ret, false).await
});
@ -515,8 +514,8 @@ pub mod tests {
println!("connect: {:?}", tunnel.info());
assert_eq!(
tunnel.info().unwrap().remote_addr,
connector.remote_url().to_string()
url::Url::from(tunnel.info().unwrap().remote_addr.unwrap()),
connector.remote_url(),
);
let (mut recv, mut send) = tunnel.split();

View File

@ -3,10 +3,11 @@ use std::{
task::{Context, Poll},
};
use crate::rpc::TunnelInfo;
use auto_impl::auto_impl;
use futures::{Sink, SinkExt, Stream, StreamExt};
use crate::proto::common::TunnelInfo;
use self::stats::Throughput;
use super::*;

View File

@ -8,7 +8,7 @@ use std::fmt::Debug;
use tokio::time::error::Elapsed;
use crate::rpc::TunnelInfo;
use crate::proto::common::TunnelInfo;
use self::packet_def::ZCPacket;

View File

@ -3,9 +3,13 @@
use std::{pin::Pin, time::Duration};
use anyhow::Context;
use tokio::{task::JoinHandle, time::timeout};
use tokio::time::timeout;
use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream};
use crate::common::scoped_task::ScopedTask;
use super::{
packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
};
use tachyonix::{channel, Receiver, Sender};
@ -29,12 +33,12 @@ impl MpscTunnelSender {
}
pub struct MpscTunnel<T> {
tx: Sender<ZCPacket>,
tx: Option<Sender<ZCPacket>>,
tunnel: T,
stream: Option<Pin<Box<dyn ZCPacketStream>>>,
task: Option<JoinHandle<()>>,
task: ScopedTask<()>,
}
impl<T: Tunnel> MpscTunnel<T> {
@ -54,10 +58,10 @@ impl<T: Tunnel> MpscTunnel<T> {
});
Self {
tx,
tx: Some(tx),
tunnel,
stream: Some(stream),
task: Some(task),
task: task.into(),
}
}
@ -81,7 +85,12 @@ impl<T: Tunnel> MpscTunnel<T> {
}
pub fn get_sink(&self) -> MpscTunnelSender {
MpscTunnelSender(self.tx.clone())
MpscTunnelSender(self.tx.as_ref().unwrap().clone())
}
pub fn close(&mut self) {
self.tx.take();
self.task.abort();
}
}

View File

@ -54,6 +54,8 @@ pub enum PacketType {
Pong = 5,
TaRpc = 6,
Route = 7,
RpcReq = 8,
RpcResp = 9,
}
bitflags::bitflags! {

View File

@ -4,12 +4,10 @@
use std::{error::Error, net::SocketAddr, sync::Arc};
use crate::{
rpc::TunnelInfo,
tunnel::{
check_scheme_and_get_socket_addr_ext,
common::{FramedReader, FramedWriter, TunnelWrapper},
},
use crate::tunnel::{
check_scheme_and_get_socket_addr_ext,
common::{FramedReader, FramedWriter, TunnelWrapper},
TunnelInfo,
};
use anyhow::Context;
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Connection, Endpoint, ServerConfig};
@ -113,8 +111,10 @@ impl TunnelListener for QUICTunnelListener {
let info = TunnelInfo {
tunnel_type: "quic".to_owned(),
local_addr: self.local_url().into(),
remote_addr: super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
local_addr: Some(self.local_url().into()),
remote_addr: Some(
super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
),
};
Ok(Box::new(TunnelWrapper::new(
@ -177,8 +177,10 @@ impl TunnelConnector for QUICTunnelConnector {
let info = TunnelInfo {
tunnel_type: "quic".to_owned(),
local_addr: super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
remote_addr: self.addr.to_string(),
local_addr: Some(
super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
),
remote_addr: Some(self.addr.clone().into()),
};
let arc_conn = Arc::new(ConnWrapper { conn: connection });

View File

@ -261,8 +261,8 @@ fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
RingSink::new(conn.server.clone()),
Some(TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
local_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
}),
)
}
@ -273,8 +273,8 @@ fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
RingSink::new(conn.client.clone()),
Some(TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
local_addr: Some(build_url_from_socket_addr(&conn.server.id.into(), "ring").into()),
remote_addr: Some(build_url_from_socket_addr(&conn.client.id.into(), "ring").into()),
}),
)
}

View File

@ -4,7 +4,8 @@ use async_trait::async_trait;
use futures::stream::FuturesUnordered;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use crate::{rpc::TunnelInfo, tunnel::common::setup_sokcet2};
use super::TunnelInfo;
use crate::tunnel::common::setup_sokcet2;
use super::{
check_scheme_and_get_socket_addr, check_scheme_and_get_socket_addr_ext,
@ -56,9 +57,10 @@ impl TunnelListener for TcpTunnelListener {
stream.set_nodelay(true).unwrap();
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: self.local_url().into(),
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
.into(),
local_addr: Some(self.local_url().into()),
remote_addr: Some(
super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp").into(),
),
};
let (r, w) = stream.into_split();
@ -82,9 +84,10 @@ fn get_tunnel_with_tcp_stream(
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
.into(),
remote_addr: remote_url.into(),
local_addr: Some(
super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp").into(),
),
remote_addr: Some(remote_url.into()),
};
let (r, w) = stream.into_split();

View File

@ -15,9 +15,9 @@ use tokio::{
use tracing::{instrument, Instrument};
use super::TunnelInfo;
use crate::{
common::join_joinset_background,
rpc::TunnelInfo,
tunnel::{
build_url_from_socket_addr,
common::{reserve_buf, TunnelWrapper},
@ -317,8 +317,10 @@ impl UdpTunnelListenerData {
Box::new(RingSink::new(ring_for_send_udp)),
Some(TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: self.local_url.clone().into(),
remote_addr: build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
local_addr: Some(self.local_url.clone().into()),
remote_addr: Some(
build_url_from_socket_addr(&remote_addr.to_string(), "udp").into(),
),
}),
));
@ -607,9 +609,10 @@ impl UdpTunnelConnector {
Box::new(RingSink::new(ring_for_send_udp)),
Some(TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp")
.into(),
remote_addr: self.addr.clone().into(),
local_addr: Some(
build_url_from_socket_addr(&socket.local_addr()?.to_string(), "udp").into(),
),
remote_addr: Some(self.addr.clone().into()),
}),
)))
}
@ -708,7 +711,7 @@ impl super::TunnelConnector for UdpTunnelConnector {
#[cfg(test)]
mod tests {
use std::time::Duration;
use std::{net::IpAddr, time::Duration};
use futures::SinkExt;
use tokio::time::timeout;
@ -786,7 +789,11 @@ mod tests {
loop {
let ret = listener.accept().await.unwrap();
assert_eq!(
ret.info().unwrap().local_addr,
ret.info()
.unwrap()
.local_addr
.unwrap_or_default()
.to_string(),
listener.local_url().to_string()
);
tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
@ -801,15 +808,15 @@ mod tests {
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
send_random_data_to_socket(t1.info().unwrap().local_addr.unwrap().into()),
));
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
send_random_data_to_socket(t1.info().unwrap().remote_addr.unwrap().into()),
));
tokio::spawn(timeout(
Duration::from_secs(2),
send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
send_random_data_to_socket(t2.info().unwrap().remote_addr.unwrap().into()),
));
let sender1 = tokio::spawn(async move {
@ -854,12 +861,12 @@ mod tests {
if ips.is_empty() {
return;
}
let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
let bind_dev = get_interface_name_by_ip(&IpAddr::V4(ips[0].into()));
for ip in ips {
println!("bind to ip: {:?}, {:?}", ip, bind_dev);
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
&format!("udp://{}:11111", ip).parse().unwrap(),
&format!("udp://{}:11111", ip.to_string()).parse().unwrap(),
"udp",
)
.unwrap();

View File

@ -8,7 +8,8 @@ use tokio_rustls::TlsAcceptor;
use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message};
use zerocopy::AsBytes;
use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use super::TunnelInfo;
use crate::tunnel::insecure_tls::get_insecure_tls_client_config;
use super::{
common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
@ -72,12 +73,14 @@ impl WSTunnelListener {
async fn try_accept(&mut self, stream: TcpStream) -> Result<Box<dyn Tunnel>, TunnelError> {
let info = TunnelInfo {
tunnel_type: self.addr.scheme().to_owned(),
local_addr: self.local_url().into(),
remote_addr: super::build_url_from_socket_addr(
&stream.peer_addr()?.to_string(),
self.addr.scheme().to_string().as_str(),
)
.into(),
local_addr: Some(self.local_url().into()),
remote_addr: Some(
super::build_url_from_socket_addr(
&stream.peer_addr()?.to_string(),
self.addr.scheme().to_string().as_str(),
)
.into(),
),
};
let server_bulder = tokio_websockets::ServerBuilder::new().limits(Limits::unlimited());
@ -182,12 +185,14 @@ impl WSTunnelConnector {
let info = TunnelInfo {
tunnel_type: addr.scheme().to_owned(),
local_addr: super::build_url_from_socket_addr(
&stream.local_addr()?.to_string(),
addr.scheme().to_string().as_str(),
)
.into(),
remote_addr: addr.to_string(),
local_addr: Some(
super::build_url_from_socket_addr(
&stream.local_addr()?.to_string(),
addr.scheme().to_string().as_str(),
)
.into(),
),
remote_addr: Some(addr.clone().into()),
};
let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap());

View File

@ -20,13 +20,11 @@ use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use crate::{
rpc::TunnelInfo,
tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
},
use super::TunnelInfo;
use crate::tunnel::{
build_url_from_socket_addr,
common::TunnelWrapper,
packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
};
use super::{
@ -522,12 +520,16 @@ impl WgTunnelListener {
sink,
Some(TunnelInfo {
tunnel_type: "wg".to_owned(),
local_addr: build_url_from_socket_addr(
&socket.local_addr().unwrap().to_string(),
"wg",
)
.into(),
remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
local_addr: Some(
build_url_from_socket_addr(
&socket.local_addr().unwrap().to_string(),
"wg",
)
.into(),
),
remote_addr: Some(
build_url_from_socket_addr(&addr.to_string(), "wg").into(),
),
}),
));
if let Err(e) = conn_sender.send(tunnel) {
@ -670,8 +672,8 @@ impl WgTunnelConnector {
sink,
Some(TunnelInfo {
tunnel_type: "wg".to_owned(),
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
remote_addr: addr_url.to_string(),
local_addr: Some(super::build_url_from_socket_addr(&local_addr, "wg").into()),
remote_addr: Some(addr_url.into()),
}),
Some(Box::new(wg_peer)),
));

View File

@ -5,7 +5,10 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
use crate::{
common::{config::ConfigLoader, get_logger_timer_rfc3339},
rpc::cli::{NatType, PeerInfo, Route},
proto::{
cli::{PeerInfo, Route},
common::NatType,
},
};
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -114,17 +117,11 @@ pub fn list_peer_route_pair(peers: Vec<PeerInfo>, routes: Vec<Route>) -> Vec<Pee
for route in routes.iter() {
let peer = peers.iter().find(|peer| peer.peer_id == route.peer_id);
let has_tunnel = peer.map(|p| !p.conns.is_empty()).unwrap_or(false);
let mut pair = PeerRoutePair {
let pair = PeerRoutePair {
route: route.clone(),
peer: peer.cloned(),
};
// it is relayed by public server, adjust the cost
if !has_tunnel && pair.route.cost == 1 {
pair.route.cost = 2;
}
pairs.push(pair);
}

View File

@ -89,8 +89,8 @@ impl WireGuardImpl {
peer_mgr
.get_global_ctx()
.issue_event(GlobalCtxEvent::VpnPortalClientConnected(
info.local_addr.clone(),
info.remote_addr.clone(),
info.local_addr.clone().unwrap_or_default().to_string(),
info.remote_addr.clone().unwrap_or_default().to_string(),
));
let mut map_key = None;
@ -120,7 +120,7 @@ impl WireGuardImpl {
};
if !ip_registered {
let client_entry = Arc::new(ClientEntry {
endpoint_addr: remote_addr.parse().ok(),
endpoint_addr: remote_addr.clone().map(Into::into),
sink: mpsc_tunnel.get_sink(),
});
map_key = Some(i.get_source());
@ -142,8 +142,8 @@ impl WireGuardImpl {
peer_mgr
.get_global_ctx()
.issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
info.local_addr,
info.remote_addr,
info.local_addr.unwrap_or_default().to_string(),
info.remote_addr.unwrap_or_default().to_string(),
));
}