This commit is contained in:
Sijie.Sun 2024-11-14 23:10:04 +08:00 committed by GitHub
commit 5cd3f1218b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 491 additions and 74 deletions

37
Cargo.lock generated
View File

@ -262,6 +262,20 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "async-compression"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
dependencies = [
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"zstd 0.13.2",
"zstd-safe 7.2.1",
]
[[package]] [[package]]
name = "async-event" name = "async-event"
version = "0.2.1" version = "0.2.1"
@ -1809,6 +1823,7 @@ version = "2.0.3"
dependencies = [ dependencies = [
"aes-gcm", "aes-gcm",
"anyhow", "anyhow",
"async-compression",
"async-recursion", "async-recursion",
"async-ringbuf", "async-ringbuf",
"async-stream", "async-stream",
@ -9474,7 +9489,7 @@ dependencies = [
"pbkdf2", "pbkdf2",
"sha1", "sha1",
"time", "time",
"zstd", "zstd 0.11.2+zstd.1.5.2",
] ]
[[package]] [[package]]
@ -9483,7 +9498,16 @@ version = "0.11.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
dependencies = [ dependencies = [
"zstd-safe", "zstd-safe 5.0.2+zstd.1.5.2",
]
[[package]]
name = "zstd"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
dependencies = [
"zstd-safe 7.2.1",
] ]
[[package]] [[package]]
@ -9496,6 +9520,15 @@ dependencies = [
"zstd-sys", "zstd-sys",
] ]
[[package]]
name = "zstd-safe"
version = "7.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
dependencies = [
"zstd-sys",
]
[[package]] [[package]]
name = "zstd-sys" name = "zstd-sys"
version = "2.0.13+zstd.1.5.6" version = "2.0.13+zstd.1.5.6"

View File

@ -1,6 +1,7 @@
import type { NetworkTypes } from 'easytier-frontend-lib'
import { addPluginListener } from '@tauri-apps/api/core' import { addPluginListener } from '@tauri-apps/api/core'
import { Utils } from 'easytier-frontend-lib'
import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api' import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api'
import { NetworkTypes, Utils } from 'easytier-frontend-lib'
type Route = NetworkTypes.Route type Route = NetworkTypes.Route

View File

@ -1,5 +1,5 @@
import type { NetworkTypes } from 'easytier-frontend-lib'
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
import { NetworkTypes } from 'easytier-frontend-lib'
type NetworkConfig = NetworkTypes.NetworkConfig type NetworkConfig = NetworkTypes.NetworkConfig
type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo

View File

@ -181,6 +181,8 @@ sys-locale = "0.3"
ringbuf = "0.4.5" ringbuf = "0.4.5"
async-ringbuf = "0.3.1" async-ringbuf = "0.3.1"
async-compression = { version = "0.4.17", default-features = false, features = ["zstd", "tokio"] }
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies] [target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
machine-uid = "0.5.3" machine-uid = "0.5.3"

View File

@ -0,0 +1,174 @@
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
use tokio::io::AsyncWriteExt;
use zerocopy::{AsBytes as _, FromBytes as _};
use crate::tunnel::packet_def::{
CompressorAlgo, CompressorTail, PacketType, ZCPacket, COMPRESSOR_TAIL_SIZE,
};
type Error = anyhow::Error;
#[async_trait::async_trait]
pub trait Compressor {
async fn compress(
&self,
packet: &mut ZCPacket,
compress_algo: CompressorAlgo,
) -> Result<(), Error>;
async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>;
}
pub struct DefaultCompressor {}
impl DefaultCompressor {
pub fn new() -> Self {
DefaultCompressor {}
}
pub async fn compress_raw(
&self,
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdEncoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
}
pub async fn decompress_raw(
&self,
data: &[u8],
compress_algo: CompressorAlgo,
) -> Result<Vec<u8>, Error> {
let buf = match compress_algo {
CompressorAlgo::ZstdDefault => {
let mut o = ZstdDecoder::new(Vec::new());
o.write_all(data).await?;
o.shutdown().await?;
o.into_inner()
}
CompressorAlgo::None => data.to_vec(),
};
Ok(buf)
}
}
#[async_trait::async_trait]
impl Compressor for DefaultCompressor {
async fn compress(
&self,
zc_packet: &mut ZCPacket,
compress_algo: CompressorAlgo,
) -> Result<(), Error> {
if matches!(compress_algo, CompressorAlgo::None) {
return Ok(());
}
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_compressed() || pm_header.packet_type != PacketType::Data as u8 {
// only compress data packets
return Ok(());
}
let tail = CompressorTail::new(compress_algo);
let buf = self
.compress_raw(zc_packet.payload(), compress_algo)
.await?;
if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize {
// Compressed data is larger than original data, don't compress
return Ok(());
}
zc_packet
.mut_peer_manager_header()
.unwrap()
.set_compressed(true);
let payload_offset = zc_packet.payload_offset();
zc_packet.mut_inner().truncate(payload_offset);
zc_packet.mut_inner().extend_from_slice(&buf);
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
Ok(())
}
async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if !pm_header.is_compressed() {
return Ok(());
}
let payload_len = zc_packet.payload().len();
if payload_len < COMPRESSOR_TAIL_SIZE {
return Err(anyhow::anyhow!("Packet too short: {}", payload_len));
}
let text_len = payload_len - COMPRESSOR_TAIL_SIZE;
let tail = CompressorTail::ref_from_suffix(zc_packet.payload())
.unwrap()
.clone();
let algo = tail
.get_algo()
.ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?;
let buf = self
.decompress_raw(&zc_packet.payload()[..text_len], algo)
.await?;
if buf.len() != pm_header.len.get() as usize {
anyhow::bail!(
"Decompressed length mismatch: decompressed len {} != pm header len {}",
buf.len(),
pm_header.len.get()
);
}
zc_packet
.mut_peer_manager_header()
.unwrap()
.set_compressed(false);
let payload_offset = zc_packet.payload_offset();
zc_packet.mut_inner().truncate(payload_offset);
zc_packet.mut_inner().extend_from_slice(&buf);
Ok(())
}
}
#[cfg(test)]
pub mod tests {
use super::*;
#[tokio::test]
async fn test_compress() {
let text = b"12345670000000000000000000";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
let compressor = DefaultCompressor {};
compressor
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.await
.unwrap();
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
}
}

View File

@ -6,6 +6,7 @@ use std::{
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tracing::Instrument; use tracing::Instrument;
pub mod compressor;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod defer; pub mod defer;

View File

@ -20,6 +20,7 @@ use tokio::{
use crate::{ use crate::{
common::{ common::{
compressor::{Compressor as _, DefaultCompressor},
constants::EASYTIER_VERSION, constants::EASYTIER_VERSION,
error::Error, error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity}, global_ctx::{ArcGlobalCtx, NetworkIdentity},
@ -41,7 +42,7 @@ use crate::{
}, },
tunnel::{ tunnel::{
self, self,
packet_def::{PacketType, ZCPacket}, packet_def::{CompressorAlgo, PacketType, ZCPacket},
SinkItem, Tunnel, TunnelConnector, SinkItem, Tunnel, TunnelConnector,
}, },
}; };
@ -61,6 +62,7 @@ use super::{
struct RpcTransport { struct RpcTransport {
my_peer_id: PeerId, my_peer_id: PeerId,
peers: Weak<PeerMap>, peers: Weak<PeerMap>,
// TODO: this seems can be removed
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>, foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>, packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
@ -76,48 +78,14 @@ impl PeerRpcManagerTransport for RpcTransport {
} }
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
let foreign_peers = self
.foreign_peers
.lock()
.await
.as_ref()
.ok_or(Error::Unknown)?
.upgrade()
.ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?; let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
if peers.need_relay_by_foreign_network(dst_peer_id).await? {
if foreign_peers.has_next_hop(dst_peer_id) {
// do not encrypt for data sending to public server
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
foreign_peers.send_msg(msg, dst_peer_id).await
} else if let Some(gateway_id) = peers
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
.await
{
tracing::trace!(
?dst_peer_id,
?gateway_id,
?self.my_peer_id,
"send msg to peer via gateway",
);
self.encryptor self.encryptor
.encrypt(&mut msg) .encrypt(&mut msg)
.with_context(|| "encrypt failed")?; .with_context(|| "encrypt failed")?;
if peers.has_peer(gateway_id) {
peers.send_msg_directly(msg, gateway_id).await
} else {
foreign_peers.send_msg(msg, gateway_id).await
}
} else {
Err(Error::RouteError(Some(format!(
"peermgr RpcTransport no route for dst_peer_id: {}",
dst_peer_id
))))
} }
// send to self and this packet will be forwarded in peer_recv loop
peers.send_msg_directly(msg, self.my_peer_id).await
} }
async fn recv(&self) -> Result<ZCPacket, Error> { async fn recv(&self) -> Result<ZCPacket, Error> {
@ -163,6 +131,7 @@ pub struct PeerManager {
foreign_network_client: Arc<ForeignNetworkClient>, foreign_network_client: Arc<ForeignNetworkClient>,
encryptor: Arc<Box<dyn Encryptor>>, encryptor: Arc<Box<dyn Encryptor>>,
data_compress_algo: CompressorAlgo,
exit_nodes: Vec<Ipv4Addr>, exit_nodes: Vec<Ipv4Addr>,
} }
@ -272,6 +241,8 @@ impl PeerManager {
foreign_network_client, foreign_network_client,
encryptor, encryptor,
data_compress_algo: CompressorAlgo::None,
exit_nodes, exit_nodes,
} }
} }
@ -465,6 +436,12 @@ impl PeerManager {
continue; continue;
} }
let compressor = DefaultCompressor {};
if let Err(e) = compressor.decompress(&mut ret).await {
tracing::error!(?e, "decompress failed");
continue;
}
let mut processed = false; let mut processed = false;
let mut zc_packet = Some(ret); let mut zc_packet = Some(ret);
let mut idx = 0; let mut idx = 0;
@ -768,6 +745,11 @@ impl PeerManager {
tunnel::packet_def::PacketType::Data as u8, tunnel::packet_def::PacketType::Data as u8,
); );
self.run_nic_packet_process_pipeline(&mut msg).await; self.run_nic_packet_process_pipeline(&mut msg).await;
let compressor = DefaultCompressor {};
compressor
.compress(&mut msg, self.data_compress_algo)
.await
.with_context(|| "compress failed")?;
self.encryptor self.encryptor
.encrypt(&mut msg) .encrypt(&mut msg)
.with_context(|| "encrypt failed")?; .with_context(|| "encrypt failed")?;

View File

@ -250,6 +250,19 @@ impl PeerMap {
} }
route_map route_map
} }
pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result<bool, Error> {
// if gateway_peer_id is not connected to me, means need relay by foreign network
let gateway_id = self
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
.await
.ok_or(Error::RouteError(Some(format!(
"peer map need_relay_by_foreign_network no gateway for dst_peer_id: {}",
dst_peer_id
))))?;
Ok(!self.has_peer(gateway_id))
}
} }
impl Drop for PeerMap { impl Drop for PeerMap {

View File

@ -32,7 +32,7 @@ message RpcDescriptor {
} }
message RpcRequest { message RpcRequest {
RpcDescriptor descriptor = 1; RpcDescriptor descriptor = 1 [ deprecated = true ];
bytes request = 2; bytes request = 2;
int32 timeout_ms = 3; int32 timeout_ms = 3;
@ -45,6 +45,21 @@ message RpcResponse {
uint64 runtime_us = 3; uint64 runtime_us = 3;
} }
enum CompressionAlgoPb {
Invalid = 0;
None = 1;
Zstd = 2;
}
message RpcCompressionInfo {
// use this to compress the content
CompressionAlgoPb algo = 1;
// tell the peer which compression algo is used to compress the next
// response/request
CompressionAlgoPb accepted_algo = 2;
}
message RpcPacket { message RpcPacket {
uint32 from_peer = 1; uint32 from_peer = 1;
uint32 to_peer = 2; uint32 to_peer = 2;
@ -58,6 +73,8 @@ message RpcPacket {
uint32 piece_idx = 8; uint32 piece_idx = 8;
int32 trace_id = 9; int32 trace_id = 9;
RpcCompressionInfo compression_info = 10;
} }
message Void {} message Void {}

View File

@ -2,6 +2,8 @@ use std::{fmt::Display, str::FromStr};
use anyhow::Context; use anyhow::Context;
use crate::tunnel::packet_def::CompressorAlgo;
include!(concat!(env!("OUT_DIR"), "/common.rs")); include!(concat!(env!("OUT_DIR"), "/common.rs"));
impl From<uuid::Uuid> for Uuid { impl From<uuid::Uuid> for Uuid {
@ -180,3 +182,26 @@ impl From<SocketAddr> for std::net::SocketAddr {
} }
} }
} }
impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
type Error = anyhow::Error;
fn try_from(value: CompressionAlgoPb) -> Result<Self, Self::Error> {
match value {
CompressionAlgoPb::Zstd => Ok(CompressorAlgo::ZstdDefault),
CompressionAlgoPb::None => Ok(CompressorAlgo::None),
_ => Err(anyhow::anyhow!("Invalid CompressionAlgoPb")),
}
}
}
impl TryFrom<CompressorAlgo> for CompressionAlgoPb {
type Error = anyhow::Error;
fn try_from(value: CompressorAlgo) -> Result<Self, Self::Error> {
match value {
CompressorAlgo::ZstdDefault => Ok(CompressionAlgoPb::Zstd),
CompressorAlgo::None => Ok(CompressionAlgoPb::None),
}
}
}

View File

@ -12,8 +12,10 @@ use tokio_stream::StreamExt;
use crate::common::PeerId; use crate::common::PeerId;
use crate::defer; use crate::defer;
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse}; use crate::proto::common::{
use crate::proto::rpc_impl::packet::build_rpc_packet; CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse,
};
use crate::proto::rpc_impl::packet::{build_rpc_packet, compress_packet, decompress_packet};
use crate::proto::rpc_types::controller::Controller; use crate::proto::rpc_types::controller::Controller;
use crate::proto::rpc_types::descriptor::MethodDescriptor; use crate::proto::rpc_types::descriptor::MethodDescriptor;
use crate::proto::rpc_types::{ use crate::proto::rpc_types::{
@ -48,12 +50,21 @@ struct InflightRequest {
start_time: std::time::Instant, start_time: std::time::Instant,
} }
#[derive(Debug, Clone, Default)]
struct PeerInfo {
peer_id: PeerId,
compression_info: RpcCompressionInfo,
last_active: Option<std::time::Instant>,
}
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>; type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
type PeerInfoTable = Arc<DashMap<PeerId, PeerInfo>>;
pub struct Client { pub struct Client {
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>, mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
transport: Mutex<Transport>, transport: Mutex<Transport>,
inflight_requests: InflightRequestTable, inflight_requests: InflightRequestTable,
peer_info: PeerInfoTable,
tasks: Arc<Mutex<JoinSet<()>>>, tasks: Arc<Mutex<JoinSet<()>>>,
} }
@ -64,6 +75,7 @@ impl Client {
mpsc: Mutex::new(MpscTunnel::new(ring_a, None)), mpsc: Mutex::new(MpscTunnel::new(ring_a, None)),
transport: Mutex::new(MpscTunnel::new(ring_b, None)), transport: Mutex::new(MpscTunnel::new(ring_b, None)),
inflight_requests: Arc::new(DashMap::new()), inflight_requests: Arc::new(DashMap::new()),
peer_info: Arc::new(DashMap::new()),
tasks: Arc::new(Mutex::new(JoinSet::new())), tasks: Arc::new(Mutex::new(JoinSet::new())),
} }
} }
@ -79,6 +91,21 @@ impl Client {
pub fn run(&self) { pub fn run(&self) {
let mut tasks = self.tasks.lock().unwrap(); let mut tasks = self.tasks.lock().unwrap();
let peer_infos = self.peer_info.clone();
tasks.spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
let now = std::time::Instant::now();
peer_infos.retain(|_, v| {
if let Some(last_active) = v.last_active {
return now.duration_since(last_active)
< std::time::Duration::from_secs(120);
}
true
});
}
});
let mut rx = self.mpsc.lock().unwrap().get_stream(); let mut rx = self.mpsc.lock().unwrap().get_stream();
let inflight_requests = self.inflight_requests.clone(); let inflight_requests = self.inflight_requests.clone();
tasks.spawn(async move { tasks.spawn(async move {
@ -111,6 +138,8 @@ impl Client {
continue; continue;
}; };
tracing::trace!(?packet, "Received response packet");
let ret = inflight_request.merger.feed(packet); let ret = inflight_request.merger.feed(packet);
match ret { match ret {
Ok(Some(rpc_packet)) => { Ok(Some(rpc_packet)) => {
@ -138,6 +167,7 @@ impl Client {
to_peer_id: PeerId, to_peer_id: PeerId,
zc_packet_sender: MpscTunnelSender, zc_packet_sender: MpscTunnelSender,
inflight_requests: InflightRequestTable, inflight_requests: InflightRequestTable,
peer_info: PeerInfoTable,
_phan: PhantomData<F>, _phan: PhantomData<F>,
} }
@ -194,23 +224,53 @@ impl Client {
}; };
let rpc_req = RpcRequest { let rpc_req = RpcRequest {
descriptor: Some(rpc_desc.clone()),
request: input.into(), request: input.into(),
timeout_ms: ctrl.timeout_ms(), timeout_ms: ctrl.timeout_ms(),
..Default::default()
}; };
let peer_info = self
.peer_info
.get(&self.to_peer_id)
.map(|v| v.clone())
.unwrap_or_default();
let (buf, c_algo) = compress_packet(
peer_info.compression_info.accepted_algo(),
&rpc_req.encode_to_vec(),
)
.await
.unwrap();
let packets = build_rpc_packet( let packets = build_rpc_packet(
self.from_peer_id, self.from_peer_id,
self.to_peer_id, self.to_peer_id,
rpc_desc, rpc_desc,
transaction_id, transaction_id,
true, true,
&rpc_req.encode_to_vec(), &buf,
ctrl.trace_id(), ctrl.trace_id(),
RpcCompressionInfo {
algo: c_algo.into(),
accepted_algo: CompressionAlgoPb::Zstd.into(),
},
); );
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64); let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??; let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
if let Some(compression_info) = rpc_packet.compression_info {
self.peer_info.insert(
self.to_peer_id,
PeerInfo {
peer_id: self.to_peer_id,
compression_info: compression_info.clone(),
last_active: Some(std::time::Instant::now()),
},
);
rpc_packet.body =
decompress_packet(compression_info.algo(), &rpc_packet.body).await?;
}
assert_eq!(rpc_packet.transaction_id, transaction_id); assert_eq!(rpc_packet.transaction_id, transaction_id);
@ -230,6 +290,7 @@ impl Client {
to_peer_id, to_peer_id,
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(), zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
inflight_requests: self.inflight_requests.clone(), inflight_requests: self.inflight_requests.clone(),
peer_info: self.peer_info.clone(),
_phan: PhantomData, _phan: PhantomData,
}) })
} }

View File

@ -1,18 +1,44 @@
use prost::Message as _; use prost::Message as _;
use crate::{ use crate::{
common::PeerId, common::{compressor::DefaultCompressor, PeerId},
proto::{ proto::{
common::{RpcDescriptor, RpcPacket}, common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket},
rpc_types::error::Error, rpc_types::error::Error,
}, },
tunnel::packet_def::{PacketType, ZCPacket}, tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket},
}; };
use super::RpcTransactId; use super::RpcTransactId;
const RPC_PACKET_CONTENT_MTU: usize = 1300; const RPC_PACKET_CONTENT_MTU: usize = 1300;
pub async fn compress_packet(
accepted_compression_algo: CompressionAlgoPb,
content: &[u8],
) -> Result<(Vec<u8>, CompressionAlgoPb), Error> {
let compressor = DefaultCompressor::new();
let algo = accepted_compression_algo
.try_into()
.unwrap_or(CompressorAlgo::None);
let compressed = compressor.compress_raw(&content, algo).await?;
if compressed.len() >= content.len() {
Ok((content.to_vec(), CompressionAlgoPb::None))
} else {
Ok((compressed, algo.try_into().unwrap()))
}
}
pub async fn decompress_packet(
compression_algo: CompressionAlgoPb,
content: &[u8],
) -> Result<Vec<u8>, Error> {
let compressor = DefaultCompressor::new();
let algo = compression_algo.try_into()?;
let decompressed = compressor.decompress_raw(&content, algo).await?;
Ok(decompressed)
}
pub struct PacketMerger { pub struct PacketMerger {
first_piece: Option<RpcPacket>, first_piece: Option<RpcPacket>,
pieces: Vec<RpcPacket>, pieces: Vec<RpcPacket>,
@ -46,7 +72,8 @@ impl PacketMerger {
body.extend_from_slice(&p.body); body.extend_from_slice(&p.body);
} }
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone(); // only the first packet contains the complete info
let mut tmpl_packet = self.pieces[0].clone();
tmpl_packet.total_pieces = 1; tmpl_packet.total_pieces = 1;
tmpl_packet.piece_idx = 0; tmpl_packet.piece_idx = 0;
tmpl_packet.body = body; tmpl_packet.body = body;
@ -58,17 +85,17 @@ impl PacketMerger {
let total_pieces = rpc_packet.total_pieces; let total_pieces = rpc_packet.total_pieces;
let piece_idx = rpc_packet.piece_idx; let piece_idx = rpc_packet.piece_idx;
if rpc_packet.descriptor.is_none() {
return Err(Error::MalformatRpcPacket(
"descriptor is missing".to_owned(),
));
}
// for compatibility with old version // for compatibility with old version
if total_pieces == 0 && piece_idx == 0 { if total_pieces == 0 && piece_idx == 0 {
return Ok(Some(rpc_packet)); return Ok(Some(rpc_packet));
} }
if rpc_packet.piece_idx == 0 && rpc_packet.descriptor.is_none() {
return Err(Error::MalformatRpcPacket(
"descriptor is missing".to_owned(),
));
}
// about 32MB max size // about 32MB max size
if total_pieces > 32 * 1024 || total_pieces == 0 { if total_pieces > 32 * 1024 || total_pieces == 0 {
return Err(Error::MalformatRpcPacket(format!( return Err(Error::MalformatRpcPacket(format!(
@ -89,6 +116,7 @@ impl PacketMerger {
{ {
self.first_piece = Some(rpc_packet.clone()); self.first_piece = Some(rpc_packet.clone());
self.pieces.clear(); self.pieces.clear();
tracing::trace!(?rpc_packet, "got first piece");
} }
self.pieces self.pieces
@ -113,6 +141,7 @@ pub fn build_rpc_packet(
is_req: bool, is_req: bool,
content: &Vec<u8>, content: &Vec<u8>,
trace_id: i32, trace_id: i32,
compression_info: RpcCompressionInfo,
) -> Vec<ZCPacket> { ) -> Vec<ZCPacket> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let content_mtu = RPC_PACKET_CONTENT_MTU; let content_mtu = RPC_PACKET_CONTENT_MTU;
@ -130,13 +159,22 @@ pub fn build_rpc_packet(
let cur_packet = RpcPacket { let cur_packet = RpcPacket {
from_peer, from_peer,
to_peer, to_peer,
descriptor: Some(rpc_desc.clone()), descriptor: if cur_offset == 0 {
Some(rpc_desc.clone())
} else {
None
},
is_request: is_req, is_request: is_req,
total_pieces: total_pieces as u32, total_pieces: total_pieces as u32,
piece_idx: (cur_offset / content_mtu) as u32, piece_idx: (cur_offset / content_mtu) as u32,
transaction_id, transaction_id,
body: cur_content, body: cur_content,
trace_id, trace_id,
compression_info: if cur_offset == 0 {
Some(compression_info.clone())
} else {
None
},
}; };
cur_offset += cur_len; cur_offset += cur_len;

View File

@ -12,7 +12,10 @@ use tokio_stream::StreamExt;
use crate::{ use crate::{
common::{join_joinset_background, PeerId}, common::{join_joinset_background, PeerId},
proto::{ proto::{
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse}, common::{
self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest,
RpcResponse,
},
rpc_types::error::Result, rpc_types::error::Result,
}, },
tunnel::{ tunnel::{
@ -23,7 +26,7 @@ use crate::{
}; };
use super::{ use super::{
packet::{build_rpc_packet, PacketMerger}, packet::{build_rpc_packet, compress_packet, decompress_packet, PacketMerger},
service_registry::ServiceRegistry, service_registry::ServiceRegistry,
RpcController, Transport, RpcController, Transport,
}; };
@ -31,7 +34,6 @@ use super::{
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PacketMergerKey { struct PacketMergerKey {
from_peer_id: PeerId, from_peer_id: PeerId,
rpc_desc: RpcDescriptor,
transaction_id: i64, transaction_id: i64,
} }
@ -108,10 +110,11 @@ impl Server {
let key = PacketMergerKey { let key = PacketMergerKey {
from_peer_id: packet.from_peer, from_peer_id: packet.from_peer,
rpc_desc: packet.descriptor.clone().unwrap_or_default(),
transaction_id: packet.transaction_id, transaction_id: packet.transaction_id,
}; };
tracing::trace!(?key, ?packet, "Received request packet");
let ret = packet_merges let ret = packet_merges
.entry(key.clone()) .entry(key.clone())
.or_insert_with(PacketMerger::new) .or_insert_with(PacketMerger::new)
@ -144,7 +147,16 @@ impl Server {
} }
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> { async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?; let body = if let Some(compression_info) = packet.compression_info {
decompress_packet(
compression_info.algo.try_into().unwrap_or_default(),
&packet.body,
)
.await?
} else {
packet.body
};
let rpc_request = RpcRequest::decode(Bytes::from(body))?;
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64); let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
let ctrl = RpcController::default(); let ctrl = RpcController::default();
Ok(timeout( Ok(timeout(
@ -168,6 +180,7 @@ impl Server {
let mut resp_msg = RpcResponse::default(); let mut resp_msg = RpcResponse::default();
let now = std::time::Instant::now(); let now = std::time::Instant::now();
let compression_info = packet.compression_info.clone();
let resp_bytes = Self::handle_rpc_request(packet, reg).await; let resp_bytes = Self::handle_rpc_request(packet, reg).await;
match &resp_bytes { match &resp_bytes {
@ -180,14 +193,25 @@ impl Server {
}; };
resp_msg.runtime_us = now.elapsed().as_micros() as u64; resp_msg.runtime_us = now.elapsed().as_micros() as u64;
let (compressed_resp, algo) = compress_packet(
compression_info.unwrap_or_default().accepted_algo(),
&resp_msg.encode_to_vec(),
)
.await
.unwrap();
let packets = build_rpc_packet( let packets = build_rpc_packet(
to_peer, to_peer,
from_peer, from_peer,
desc, desc,
transaction_id, transaction_id,
false, false,
&resp_msg.encode_to_vec(), &compressed_resp,
trace_id, trace_id,
RpcCompressionInfo {
algo: algo.into(),
accepted_algo: CompressionAlgoPb::Zstd.into(),
},
); );
for packet in packets { for packet in packets {

View File

@ -107,6 +107,7 @@ fn random_string(len: usize) -> String {
#[tokio::test] #[tokio::test]
async fn rpc_basic_test() { async fn rpc_basic_test() {
// enable_log();
let ctx = TestContext::new(); let ctx = TestContext::new();
let server = GreetingServer::new(GreetingService { let server = GreetingServer::new(GreetingService {
@ -119,7 +120,7 @@ async fn rpc_basic_test() {
.client .client
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string()); .scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());
// small size req and resp // // small size req and resp
let ctrl = RpcController::default(); let ctrl = RpcController::default();
let input = SayHelloRequest { let input = SayHelloRequest {

View File

@ -603,7 +603,7 @@ pub mod tests {
pub fn enable_log() { pub fn enable_log() {
let filter = tracing_subscriber::EnvFilter::builder() let filter = tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into()) .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
.from_env() .from_env()
.unwrap() .unwrap()
.add_directive("tarpc=error".parse().unwrap()); .add_directive("tarpc=error".parse().unwrap());

View File

@ -7,6 +7,10 @@ use zerocopy::FromZeroes;
type DefaultEndian = LittleEndian; type DefaultEndian = LittleEndian;
const fn max(a: usize, b: usize) -> usize {
[a, b][(a < b) as usize]
}
// TCP TunnelHeader // TCP TunnelHeader
#[repr(C, packed)] #[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] #[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
@ -49,11 +53,11 @@ pub enum PacketType {
Invalid = 0, Invalid = 0,
Data = 1, Data = 1,
HandShake = 2, HandShake = 2,
RoutePacket = 3, RoutePacket = 3, // deprecated
Ping = 4, Ping = 4,
Pong = 5, Pong = 5,
TaRpc = 6, TaRpc = 6, // deprecated
Route = 7, Route = 7, // deprecated
RpcReq = 8, RpcReq = 8,
RpcResp = 9, RpcResp = 9,
ForeignNetworkPacket = 10, ForeignNetworkPacket = 10,
@ -65,6 +69,7 @@ bitflags::bitflags! {
const LATENCY_FIRST = 0b0000_0010; const LATENCY_FIRST = 0b0000_0010;
const EXIT_NODE = 0b0000_0100; const EXIT_NODE = 0b0000_0100;
const NO_PROXY = 0b0000_1000; const NO_PROXY = 0b0000_1000;
const COMPRESSED = 0b0001_0000;
const _ = !0; const _ = !0;
} }
@ -118,6 +123,12 @@ impl PeerManagerHeader {
.contains(PeerManagerHeaderFlags::NO_PROXY) .contains(PeerManagerHeaderFlags::NO_PROXY)
} }
pub fn is_compressed(&self) -> bool {
PeerManagerHeaderFlags::from_bits(self.flags)
.unwrap()
.contains(PeerManagerHeaderFlags::COMPRESSED)
}
pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self { pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self {
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
if latency_first { if latency_first {
@ -150,6 +161,17 @@ impl PeerManagerHeader {
self.flags = flags.bits(); self.flags = flags.bits();
self self
} }
pub fn set_compressed(&mut self, compressed: bool) -> &mut Self {
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
if compressed {
flags.insert(PeerManagerHeaderFlags::COMPRESSED);
} else {
flags.remove(PeerManagerHeaderFlags::COMPRESSED);
}
self.flags = flags.bits();
self
}
} }
#[repr(C, packed)] #[repr(C, packed)]
@ -201,12 +223,35 @@ pub struct AesGcmTail {
} }
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>(); pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED; #[derive(AsBytes, FromZeroes, Clone, Debug, Copy)]
#[repr(u8)]
const fn max(a: usize, b: usize) -> usize { pub enum CompressorAlgo {
[a, b][(a < b) as usize] None = 0,
ZstdDefault = 1,
} }
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct CompressorTail {
pub algo: u8,
}
pub const COMPRESSOR_TAIL_SIZE: usize = std::mem::size_of::<CompressorTail>();
impl CompressorTail {
pub fn get_algo(&self) -> Option<CompressorAlgo> {
match self.algo {
1 => Some(CompressorAlgo::ZstdDefault),
_ => None,
}
}
pub fn new(algo: CompressorAlgo) -> Self {
Self { algo: algo as u8 }
}
}
pub const TAIL_RESERVED_SIZE: usize = max(AES_GCM_ENCRYPTION_RESERVED, COMPRESSOR_TAIL_SIZE);
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct ZCPacketOffsets { pub struct ZCPacketOffsets {
pub payload_offset: usize, pub payload_offset: usize,