Support wireguard vpn portal (#43)

* support wireguard vpn portal
  user can use wireguard client to access easytier network

* add vpn portal cli

* clean logs

* avoid OSPF sync messages growing too large
This commit is contained in:
Sijie.Sun 2024-03-30 22:15:14 +08:00 committed by GitHub
parent 90110aa587
commit 05cabb2651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 704 additions and 63 deletions

View File

@ -91,7 +91,7 @@ url = { version = "2.5", features = ["serde"] }
byteorder = "1.5.0"
# for proxy
cidr = "0.2.2"
cidr = { version = "0.2.2", features = ["serde"] }
socket2 = "0.5.5"
# for hole punching
@ -119,6 +119,8 @@ boringtun = { version = "0.6.0" }
tabled = "0.15.*"
humansize = "2.1.3"
base64 = "0.21.7"
[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.52", features = [

View File

@ -142,3 +142,18 @@ message GetGlobalPeerMapResponse {
service PeerCenterRpc {
rpc GetGlobalPeerMap (GetGlobalPeerMapRequest) returns (GetGlobalPeerMapResponse);
}
// Status of the running VPN portal (e.g. the wireguard server) on this node.
message VpnPortalInfo {
  string vpn_type = 1; // portal implementation name, e.g. "wireguard"
  string client_config = 2; // ready-to-use client config text for this portal
  repeated string connected_clients = 3; // endpoint addresses of currently connected clients
}
message GetVpnPortalInfoRequest {}
message GetVpnPortalInfoResponse {
  VpnPortalInfo vpn_portal_info = 1;
}
// RPC service exposed by the node so CLI tools can query vpn portal status.
service VpnPortalRpc {
  rpc GetVpnPortalInfo (GetVpnPortalInfoRequest) returns (GetVpnPortalInfoResponse);
}

View File

@ -42,6 +42,9 @@ pub trait ConfigLoader: Send + Sync {
fn get_rpc_portal(&self) -> Option<SocketAddr>;
fn set_rpc_portal(&self, addr: SocketAddr);
fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig>;
fn set_vpn_portal_config(&self, config: VpnPortalConfig);
fn dump(&self) -> String;
}
@ -87,6 +90,12 @@ pub struct ConsoleLoggerConfig {
pub level: Option<String>,
}
/// Settings for the optional VPN portal (wireguard server) feature.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct VpnPortalConfig {
    /// CIDR that wireguard clients take their addresses from (e.g. 10.14.14.0/24).
    pub client_cidr: cidr::Ipv4Cidr,
    /// Socket address the wireguard server listens on.
    pub wireguard_listen: SocketAddr,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
struct Config {
netns: Option<String>,
@ -103,6 +112,8 @@ struct Config {
console_logger: Option<ConsoleLoggerConfig>,
rpc_portal: Option<SocketAddr>,
vpn_portal_config: Option<VpnPortalConfig>,
}
#[derive(Debug, Clone)]
@ -314,6 +325,13 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().rpc_portal = Some(addr);
}
/// Return a copy of the configured vpn portal settings, if any.
fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig> {
    let locked = self.config.lock().unwrap();
    locked.vpn_portal_config.clone()
}

/// Store the vpn portal settings, replacing any previous value.
fn set_vpn_portal_config(&self, config: VpnPortalConfig) {
    let mut locked = self.config.lock().unwrap();
    locked.vpn_portal_config = Some(config);
}
fn dump(&self) -> String {
toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap()
}

View File

@ -28,6 +28,9 @@ pub enum GlobalCtxEvent {
Connecting(url::Url),
ConnectError(String, String), // (dst, error message)
VpnPortalClientConnected(String, String), // (portal, client ip)
VpnPortalClientDisconnected(String, String), // (portal, client ip)
}
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
@ -192,6 +195,10 @@ impl GlobalCtx {
pub fn add_running_listener(&self, url: url::Url) {
self.running_listeners.lock().unwrap().push(url);
}
/// Client CIDR of the vpn portal, or None when no portal is configured.
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
    let cfg = self.config.get_vpn_portal_config()?;
    Some(cfg.client_cidr)
}
}
#[cfg(test)]

View File

@ -54,7 +54,7 @@ pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
}
future::poll_fn(|cx| {
tracing::info!("try join joinset tasks");
tracing::debug!("try join joinset tasks");
let Some(js) = js.upgrade() else {
return std::task::Poll::Ready(());
};

View File

@ -127,7 +127,7 @@ impl Stun {
continue;
};
tracing::info!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
if msg.class() != MessageClass::SuccessResponse
|| msg.method() != BINDING
@ -194,7 +194,7 @@ impl Stun {
changed_addr
}
#[tracing::instrument(ret, err, level = Level::INFO)]
#[tracing::instrument(ret, err, level = Level::DEBUG)]
pub async fn bind_request(
&self,
source_port: u16,
@ -250,7 +250,7 @@ impl Stun {
real_port_changed,
};
tracing::info!(
tracing::debug!(
?stun_host,
?recv_addr,
?changed_socket_addr,
@ -300,7 +300,7 @@ impl UdpNatTypeDetector {
let ret = stun.bind_request(source_port, true, true).await;
if let Ok(resp) = ret {
if !resp.real_ip_changed || !resp.real_port_changed {
tracing::info!(
tracing::debug!(
?server_ip,
?ret,
"stun bind request return with unchanged ip and port"
@ -311,7 +311,7 @@ impl UdpNatTypeDetector {
}
ret_test2 = ret.ok();
ret_test3 = stun.bind_request(source_port, false, true).await.ok();
tracing::info!(?ret_test3, "stun bind request with changed port");
tracing::debug!(?ret_test3, "stun bind request with changed port");
succ = true;
break;
}
@ -320,7 +320,7 @@ impl UdpNatTypeDetector {
return NatType::Unknown;
}
tracing::info!(
tracing::debug!(
?ret_test1_1,
?ret_test1_2,
?ret_test2,

View File

@ -3,6 +3,7 @@
use std::{net::SocketAddr, vec};
use clap::{command, Args, Parser, Subcommand};
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient;
mod arch;
mod common;
@ -38,6 +39,7 @@ enum SubCommand {
Stun,
Route,
PeerCenter,
VpnPortal,
}
#[derive(Args, Debug)]
@ -216,6 +218,12 @@ impl CommandHandler {
Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?)
}
/// Open a VpnPortalRpc client against the instance's rpc portal address.
async fn get_vpn_portal_client(
    &self,
) -> Result<VpnPortalRpcClient<tonic::transport::Channel>, Error> {
    let client = VpnPortalRpcClient::connect(self.addr.clone()).await?;
    Ok(client)
}
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let mut client = self.get_peer_manager_client().await?;
let request = tonic::Request::new(ListPeerRequest::default());
@ -452,6 +460,18 @@ async fn main() -> Result<(), Error> {
.to_string()
);
}
SubCommand::VpnPortal => {
    // Query the node's VpnPortalRpc service and print the portal status:
    // portal name, a ready-to-use client config, and connected clients.
    let mut vpn_portal_client = handler.get_vpn_portal_client().await?;
    let resp = vpn_portal_client
        .get_vpn_portal_info(GetVpnPortalInfoRequest::default())
        .await?
        .into_inner()
        .vpn_portal_info
        // missing info renders as empty strings/lists instead of erroring
        .unwrap_or_default();
    println!("portal_name: {}\n", resp.vpn_type);
    println!("client_config:{}", resp.client_config);
    println!("connected_clients:\n{:#?}", resp.connected_clients);
}
}
Ok(())

View File

@ -17,9 +17,10 @@ mod peer_center;
mod peers;
mod rpc;
mod tunnels;
mod vpn_portal;
use common::{
config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig},
config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig},
get_logger_timer_rfc3339,
};
use instance::instance::Instance;
@ -105,6 +106,14 @@ struct Cli {
help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000"
)]
instance_id: Option<String>,
#[arg(
long,
help = "url that defines the vpn portal, allow other vpn clients to connect.
example: wg://0.0.0.0:11010/10.14.14.0/24, means the vpn portal is a wireguard server listening on vpn.example.com:11010,
and the vpn client is in network of 10.14.14.0/24"
)]
vpn_portal: Option<String>,
}
impl From<Cli> for TomlConfigLoader {
@ -197,6 +206,38 @@ impl From<Cli> for TomlConfigLoader {
});
}
// Configure the VPN portal from the CLI url, e.g. wg://0.0.0.0:11010/10.14.14.0/24
// means: wireguard server listening on 0.0.0.0:11010, clients in 10.14.14.0/24.
// Borrowing via `if let` avoids the original clone()/unwrap() chains on the Option.
// Parse failures intentionally panic (CLI input validation at startup).
if let Some(vpn_portal_url) = &cli.vpn_portal {
    let url: url::Url = vpn_portal_url
        .parse()
        .with_context(|| format!("failed to parse vpn portal url: {}", vpn_portal_url))
        .unwrap();
    cfg.set_vpn_portal_config(VpnPortalConfig {
        // the url path ("/10.14.14.0/24") minus its leading slash is the client cidr
        client_cidr: url.path()[1..]
            .parse()
            .with_context(|| {
                format!("failed to parse vpn portal client cidr: {}", url.path())
            })
            .unwrap(),
        // host + port of the url form the wireguard listen address
        wireguard_listen: format!("{}:{}", url.host_str().unwrap(), url.port().unwrap())
            .parse()
            .with_context(|| {
                format!(
                    "failed to parse vpn portal wireguard listen address: {}",
                    url.host_str().unwrap()
                )
            })
            .unwrap(),
    });
}
cfg
}
}
@ -337,6 +378,20 @@ pub async fn main() {
GlobalCtxEvent::ConnectError(dst, err) => {
print_event(format!("connect to peer error. dst: {}, err: {}", dst, err));
}
GlobalCtxEvent::VpnPortalClientConnected(portal, client_addr) => {
print_event(format!(
"vpn portal client connected. portal: {}, client_addr: {}",
portal, client_addr
));
}
GlobalCtxEvent::VpnPortalClientDisconnected(portal, client_addr) => {
print_event(format!(
"vpn portal client disconnected. portal: {}, client_addr: {}",
portal, client_addr
));
}
}
}
});

View File

@ -1,6 +1,6 @@
use std::borrow::BorrowMut;
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use anyhow::Context;
use futures::StreamExt;
@ -25,7 +25,10 @@ use crate::peer_center::instance::PeerCenterInstance;
use crate::peers::peer_conn::PeerConnId;
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService;
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::tunnels::SinkItem;
use crate::vpn_portal::{self, VpnPortal};
use tokio_stream::wrappers::ReceiverStream;
@ -54,6 +57,8 @@ pub struct Instance {
peer_center: Arc<PeerCenterInstance>,
vpn_portal: Arc<Mutex<Box<dyn VpnPortal>>>,
global_ctx: ArcGlobalCtx,
}
@ -102,6 +107,8 @@ impl Instance {
let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));
let vpn_portal_inst = vpn_portal::wireguard::WireGuard::default();
Instance {
inst_name: global_ctx.inst_name.clone(),
id,
@ -122,6 +129,8 @@ impl Instance {
peer_center,
vpn_portal: Arc::new(Mutex::new(Box::new(vpn_portal_inst))),
global_ctx,
}
}
@ -134,6 +143,7 @@ impl Instance {
if let Some(ipv4) = Ipv4Packet::new(&ret) {
if ipv4.get_version() != 4 {
tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
return;
}
let dst_ipv4 = ipv4.get_destination();
tracing::trace!(
@ -270,6 +280,14 @@ impl Instance {
self.add_initial_peers().await?;
if let Some(_) = self.global_ctx.get_vpn_portal_cidr() {
self.vpn_portal
.lock()
.await
.start(self.get_global_ctx(), self.get_peer_manager())
.await?;
}
Ok(())
}
@ -304,6 +322,45 @@ impl Instance {
self.peer_manager.my_peer_id()
}
// Build the tonic service that exposes vpn portal info over the rpc portal.
// The service holds only Weak references so the rpc server does not keep the
// instance (or its portal) alive after shutdown.
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc {
    struct VpnPortalRpcService {
        peer_mgr: Weak<PeerManager>,
        vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>,
    }

    #[tonic::async_trait]
    impl VpnPortalRpc for VpnPortalRpcService {
        async fn get_vpn_portal_info(
            &self,
            _request: tonic::Request<GetVpnPortalInfoRequest>,
        ) -> Result<tonic::Response<GetVpnPortalInfoResponse>, tonic::Status> {
            // If the instance was dropped, report "unavailable" instead of panicking.
            let Some(vpn_portal) = self.vpn_portal.upgrade() else {
                return Err(tonic::Status::unavailable("vpn portal not available"));
            };
            let Some(peer_mgr) = self.peer_mgr.upgrade() else {
                return Err(tonic::Status::unavailable("peer manager not available"));
            };
            let vpn_portal = vpn_portal.lock().await;
            let ret = GetVpnPortalInfoResponse {
                vpn_portal_info: Some(VpnPortalInfo {
                    vpn_type: vpn_portal.name(),
                    client_config: vpn_portal.dump_client_config(peer_mgr).await,
                    connected_clients: vpn_portal.list_clients().await,
                }),
            };
            Ok(tonic::Response::new(ret))
        }
    }

    VpnPortalRpcService {
        peer_mgr: Arc::downgrade(&self.peer_manager),
        vpn_portal: Arc::downgrade(&self.vpn_portal),
    }
}
fn run_rpc_server(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let Some(addr) = self.global_ctx.config.get_rpc_portal() else {
tracing::info!("rpc server not enabled, because rpc_portal is not set.");
@ -313,6 +370,7 @@ impl Instance {
let conn_manager = self.conn_manager.clone();
let net_ns = self.global_ctx.net_ns.clone();
let peer_center = self.peer_center.clone();
let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
self.tasks.spawn(async move {
let _g = net_ns.guard();
@ -332,6 +390,9 @@ impl Instance {
peer_center.get_rpc_service(),
),
)
.add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new(
vpn_portal_rpc,
))
.serve(addr)
.await
.with_context(|| format!("rpc server failed. addr: {}", addr))

View File

@ -242,9 +242,9 @@ impl PeerCenterInstance {
for _ in 1..10 {
peers = ctx.job_ctx.service.list_peers().await.into();
if peers == *ctx.job_ctx.last_report_peers.lock().await {
break;
return Ok(3000);
}
tokio::time::sleep(Duration::from_secs(1)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
}
*ctx.job_ctx.last_report_peers.lock().await = peers.clone();

View File

@ -189,7 +189,7 @@ impl ForeignNetworkManager {
}
pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
let entry = self
.data
@ -222,10 +222,11 @@ impl ForeignNetworkManager {
let mut s = self.global_ctx.subscribe();
self.tasks.lock().await.spawn(async move {
while let Ok(e) = s.recv().await {
tracing::warn!(?e, "global event");
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
tracing::info!(?e, "remove peer from foreign network manager");
data.remove_peer(*peer_id);
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
tracing::info!(?e, "clear no conn peer from foreign network manager");
data.clear_no_conn_peer();
}
}

View File

@ -99,6 +99,7 @@ impl RoutePeerInfo {
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
.collect(),
hostname: global_ctx.get_hostname(),
udp_stun_info: global_ctx
@ -385,6 +386,10 @@ impl RouteTable {
self.next_hop_map.get(&dst_peer_id).map(|x| *x)
}
// A peer is reachable iff the last route-table rebuild produced a next hop for it.
fn peer_reachable(&self, peer_id: PeerId) -> bool {
    self.next_hop_map.contains_key(&peer_id)
}
fn get_nat_type(&self, peer_id: PeerId) -> Option<NatType> {
self.peer_infos
.get(&peer_id)
@ -407,10 +412,10 @@ impl RouteTable {
// build next hop map
self.next_hop_map.clear();
self.next_hop_map.insert(my_peer_id, (my_peer_id, 0));
for item in self.peer_infos.iter() {
let peer_id = *item.key();
if peer_id == my_peer_id {
self.next_hop_map.insert(peer_id, (peer_id, 0));
continue;
}
let Some(path) = pathfinding::prelude::bfs(
@ -617,8 +622,7 @@ impl PeerRouteServiceImpl {
.synced_route_info
.update_my_peer_info(self.my_peer_id, &self.global_ctx)
{
self.update_cached_local_conn_bitmap();
self.update_route_table();
self.update_route_table_and_cached_local_conn_bitmap();
return true;
}
false
@ -631,8 +635,7 @@ impl PeerRouteServiceImpl {
.update_my_conn_info(self.my_peer_id, connected_peers);
if updated {
self.update_cached_local_conn_bitmap();
self.update_route_table();
self.update_route_table_and_cached_local_conn_bitmap();
}
updated
@ -643,12 +646,27 @@ impl PeerRouteServiceImpl {
.build_from_synced_info(self.my_peer_id, &self.synced_route_info);
}
fn update_cached_local_conn_bitmap(&self) {
fn update_route_table_and_cached_local_conn_bitmap(&self) {
// update route table first because we want to filter out unreachable peers.
self.update_route_table();
// the conn_bitmap should contain complete list of directly connected peers.
// use union of dst peers can preserve this property.
let all_dst_peer_ids = self
.synced_route_info
.conn_map
.iter()
.map(|x| x.value().clone().0.into_iter())
.flatten()
.collect::<BTreeSet<_>>();
let all_peer_ids = self
.synced_route_info
.conn_map
.iter()
.map(|x| (*x.key(), x.value().1.get()))
// do not sync conn info of peers that are not reachable from any peer.
.filter(|p| all_dst_peer_ids.contains(&p.0) || self.route_table.peer_reachable(p.0))
.collect::<Vec<_>>();
let mut conn_bitmap = RouteConnBitmap::new();
@ -680,6 +698,12 @@ impl PeerRouteServiceImpl {
{
continue;
}
// do not send unreachable peer info to dst peer.
if !self.route_table.peer_reachable(*item.key()) {
continue;
}
route_infos.push(item.value().clone());
}
@ -867,8 +891,7 @@ impl RouteService for RouteSessionManager {
session.update_dst_saved_conn_bitmap_version(conn_bitmap);
}
service_impl.update_cached_local_conn_bitmap();
service_impl.update_route_table();
service_impl.update_route_table_and_cached_local_conn_bitmap();
tracing::debug!(
"sync_route_info: from_peer_id: {:?}, is_initiator: {:?}, peer_infos: {:?}, conn_bitmap: {:?}, synced_route_info: {:?} session: {:?}, new_route_table: {:?}",
@ -1012,7 +1035,7 @@ impl RouteSessionManager {
.map(|x| *x)
.collect::<Vec<_>>();
tracing::info!(?service_impl.my_peer_id, ?peers, ?session_peers, ?initiator_candidates, "maintain_sessions begin");
tracing::debug!(?service_impl.my_peer_id, ?peers, ?session_peers, ?initiator_candidates, "maintain_sessions begin");
if initiator_candidates.is_empty() {
next_sleep_ms = 1000;

View File

@ -52,6 +52,7 @@ impl SyncPeerInfo {
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
.collect(),
hostname: global_ctx.get_hostname(),
udp_stun_info: global_ctx

View File

@ -28,7 +28,7 @@ use super::{
DatagramSink, DatagramStream, Tunnel, TunnelListener,
};
pub const UDP_DATA_MTU: usize = 2500;
pub const UDP_DATA_MTU: usize = 65000;
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
@ -123,7 +123,7 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
}
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::warn!(?udp_packet, "udp magic not match");
tracing::trace!(?udp_packet, "udp magic not match");
return None;
}
@ -351,7 +351,7 @@ impl TunnelListener for UdpTunnelListener {
};
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
tracing::trace!(?udp_packet, "udp magic not match");
continue;
}
@ -471,7 +471,7 @@ impl UdpTunnelConnector {
};
if udp_packet.magic != UDP_PACKET_MAGIC {
tracing::info!(?udp_packet, "udp magic not match");
tracing::trace!(?udp_packet, "udp magic not match");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, magic not match. magic: {:?}",
udp_packet.magic

View File

@ -30,14 +30,25 @@ use super::{
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener,
};
const MAX_PACKET: usize = 4096;
const MAX_PACKET: usize = 65500;
#[derive(Debug, Clone)]
enum WgType {
    // used by easytier peer, need remove/add ip header for in/out wg msg
    InternalUse,
    // used by wireguard peer, keep original ip header
    ExternalUse,
}

/// Key material for one side of a wireguard tunnel. `wg_type` selects whether
/// payloads already carry an ip header (ExternalUse) or need one added/stripped
/// when crossing the tunnel (InternalUse).
#[derive(Clone)]
pub struct WgConfig {
    my_secret_key: StaticSecret,
    my_public_key: PublicKey,
    // key pair of the expected remote peer
    peer_secret_key: StaticSecret,
    peer_public_key: PublicKey,
    wg_type: WgType,
}
impl WgConfig {
@ -56,14 +67,47 @@ impl WgConfig {
let my_secret_key = StaticSecret::from(my_sec);
let my_public_key = PublicKey::from(&my_secret_key);
let peer_secret_key = StaticSecret::from(my_sec);
let peer_public_key = my_public_key.clone();
WgConfig {
my_secret_key,
my_public_key,
peer_secret_key,
peer_public_key,
wg_type: WgType::InternalUse,
}
}
/// Build the server-side config for the vpn portal: our keys are derived from
/// `server_key_seed`, the expected client's keys from `client_key_seed`.
/// ExternalUse because real wireguard clients send complete ip packets.
pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
    let server_cfg = Self::new_from_network_identity("server", server_key_seed);
    let client_cfg = Self::new_from_network_identity("client", client_key_seed);
    Self {
        my_secret_key: server_cfg.my_secret_key,
        my_public_key: server_cfg.my_public_key,
        peer_secret_key: client_cfg.my_secret_key,
        peer_public_key: client_cfg.my_public_key,
        wg_type: WgType::ExternalUse,
    }
}

// Raw 32-byte key accessors (used e.g. for base64-encoding into client configs).
pub fn my_secret_key(&self) -> &[u8] {
    self.my_secret_key.as_bytes()
}

pub fn peer_secret_key(&self) -> &[u8] {
    self.peer_secret_key.as_bytes()
}

pub fn my_public_key(&self) -> &[u8] {
    self.my_public_key.as_bytes()
}

pub fn peer_public_key(&self) -> &[u8] {
    self.peer_public_key.as_bytes()
}
}
#[derive(Clone)]
@ -73,6 +117,7 @@ struct WgPeerData {
tunn: Arc<Mutex<Tunn>>,
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
wg_type: WgType,
}
impl Debug for WgPeerData {
@ -88,12 +133,17 @@ impl WgPeerData {
#[tracing::instrument]
async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> {
let mut send_buf = [0u8; MAX_PACKET];
let encapsulate_result = {
let mut peer = self.tunn.lock().await;
peer.encapsulate(&packet, &mut send_buf)
if matches!(self.wg_type, WgType::InternalUse) {
peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf)
} else {
peer.encapsulate(&packet, &mut send_buf)
}
};
tracing::info!(
tracing::trace!(
?encapsulate_result,
"Received {} bytes from me",
packet.len()
@ -177,9 +227,13 @@ impl WgPeerData {
.lock()
.await
.send(
WgPeer::remove_ip_header(packet, packet[0] >> 4 == 4)
.to_vec()
.into(),
if matches!(self.wg_type, WgType::InternalUse) {
self.remove_ip_header(packet, packet[0] >> 4 == 4)
} else {
packet
}
.to_vec()
.into(),
)
.await;
if ret.is_err() {
@ -250,6 +304,31 @@ impl WgPeerData {
self.handle_routine_tun_result(tun_result).await;
}
}
// Wrap a raw payload in a minimal 20-byte ipv4 header (InternalUse tunnels only).
// Only version/IHL, total length and TTL are filled in; addresses, protocol and
// checksum stay zero. NOTE(review): the receiving side strips this header again,
// so the zeroed fields appear never to be inspected — confirm nothing routes on them.
fn add_ip_header(&self, packet: &[u8]) -> Vec<u8> {
    let mut ret = vec![0u8; packet.len() + 20];
    let ip_header = ret.as_mut_slice();
    ip_header[0] = 0x45; // version 4, header length 5 words (20 bytes)
    ip_header[1] = 0; // dscp/ecn
    ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes()); // total length
    ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); // identification
    ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); // flags + fragment offset
    ip_header[8] = 64; // ttl
    ip_header[9] = 0; // protocol (left unset)
    ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); // header checksum (left unset)
    ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); // source addr (left unset)
    ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); // dest addr (left unset)
    ip_header[20..].copy_from_slice(packet);
    ret
}

// Strip the ip header: 20 bytes for ipv4, 40 for ipv6 (assumes no v6 extension headers).
fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
    if is_v4 {
        return &packet[20..];
    } else {
        return &packet[40..];
    }
}
}
struct WgPeer {
@ -277,36 +356,9 @@ impl WgPeer {
}
}
fn add_ip_header(packet: &[u8]) -> Vec<u8> {
let mut ret = vec![0u8; packet.len() + 20];
let ip_header = ret.as_mut_slice();
ip_header[0] = 0x45;
ip_header[1] = 0;
ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes());
ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
ip_header[8] = 64;
ip_header[9] = 0;
ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
ip_header[20..].copy_from_slice(packet);
ret
}
fn remove_ip_header(packet: &[u8], is_v4: bool) -> &[u8] {
if is_v4 {
return &packet[20..];
} else {
return &packet[40..];
}
}
async fn handle_packet_from_me(data: WgPeerData) {
while let Some(Ok(packet)) = data.stream.lock().await.next().await {
let ret = data
.handle_one_packet_from_me(&Self::add_ip_header(&packet))
.await;
let ret = data.handle_one_packet_from_me(&packet).await;
if let Err(e) = ret {
tracing::error!("Failed to handle packet from me: {}", e);
}
@ -315,7 +367,7 @@ impl WgPeer {
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
self.access_time = std::time::Instant::now();
tracing::info!("Received {} bytes from peer", packet.len());
tracing::trace!("Received {} bytes from peer", packet.len());
let data = self.data.as_ref().unwrap();
data.handle_one_packet_from_peer(packet).await;
}
@ -339,6 +391,7 @@ impl WgPeer {
)),
sink: Arc::new(Mutex::new(stunnel.pin_sink())),
stream: Arc::new(Mutex::new(stunnel.pin_stream())),
wg_type: self.config.wg_type.clone(),
};
self.data = Some(data.clone());
@ -349,6 +402,17 @@ impl WgPeer {
}
}
impl Drop for WgPeer {
    fn drop(&mut self) {
        // stop the peer's background tasks first, then close the sink from a
        // spawned task — close() is async and drop() cannot await
        self.tasks.abort_all();
        if let Some(data) = self.data.clone() {
            tokio::spawn(async move {
                let _ = data.sink.lock().await.close().await;
            });
        }
    }
}
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
@ -406,7 +470,7 @@ impl WgTunnelListener {
};
let data = &buf[..n];
tracing::info!("Received {} bytes from {}", n, addr);
tracing::trace!("Received {} bytes from {}", n, addr);
if !peer_map.contains_key(&addr) {
tracing::info!("New peer: {}", addr);
@ -636,13 +700,17 @@ pub mod tests {
let server_cfg = WgConfig {
my_secret_key: my_secret_key.clone(),
my_public_key,
peer_secret_key: their_secret_key.clone(),
peer_public_key: their_public_key.clone(),
wg_type: WgType::InternalUse,
};
let client_cfg = WgConfig {
my_secret_key: their_secret_key,
my_public_key: their_public_key,
peer_secret_key: my_secret_key,
peer_public_key: my_public_key,
wg_type: WgType::InternalUse,
};
(server_cfg, client_cfg)

24
src/vpn_portal/mod.rs Normal file
View File

@ -0,0 +1,24 @@
// with vpn portal, user can use other vpn client to connect to easytier servers
// without installing easytier.
// these vpn client include:
// 1. wireguard
// 2. openvpn (TODO)
// 3. shadowsocks (TODO)
use std::sync::Arc;
use crate::{common::global_ctx::ArcGlobalCtx, peers::peer_manager::PeerManager};
pub mod wireguard;
#[async_trait::async_trait]
pub trait VpnPortal: Send + Sync {
    /// Start serving the portal; called once when the easytier instance starts.
    async fn start(
        &mut self,
        global_ctx: ArcGlobalCtx,
        peer_mgr: Arc<PeerManager>,
    ) -> anyhow::Result<()>;
    /// Render a ready-to-use client config (e.g. a wg-quick config file).
    async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String;
    /// Short implementation name, e.g. "wireguard".
    fn name(&self) -> String;
    /// Endpoint addresses of the currently connected portal clients.
    async fn list_clients(&self) -> Vec<String>;
}

346
src/vpn_portal/wireguard.rs Normal file
View File

@ -0,0 +1,346 @@
use std::{
net::{Ipv4Addr, SocketAddr},
pin::Pin,
sync::Arc,
};
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine};
use cidr::Ipv4Inet;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use pnet::packet::ipv4::Ipv4Packet;
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::Bytes;
use crate::{
common::{
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
join_joinset_background,
},
peers::{
packet::{self, ArchivedPacket},
peer_manager::PeerManager,
PeerPacketFilter,
},
tunnels::{
wireguard::{WgConfig, WgTunnelListener},
DatagramSink, Tunnel, TunnelListener,
},
};
use super::VpnPortal;
// Maps a client's inner (vpn) ipv4 address to its connection entry so packets
// coming from easytier peers can be forwarded back to the right wg client.
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;

// One connected wireguard client.
struct ClientEntry {
    // remote endpoint the client connected from, if parseable as a url
    endpoint_addr: Option<url::Url>,
    // sink used to push packets back to this client
    sink: Mutex<Pin<Box<dyn DatagramSink + 'static>>>,
}

struct WireGuardImpl {
    global_ctx: ArcGlobalCtx,
    peer_mgr: Arc<PeerManager>,
    wg_config: WgConfig,
    // NOTE(review): "listenr_addr" looks like a typo for listener_addr
    listenr_addr: SocketAddr,
    wg_peer_ip_table: WgPeerIpTable,
    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
impl WireGuardImpl {
    // Derive deterministic wireguard keys from the network identity and read
    // the listen address from the vpn portal config.
    // Panics if no vpn portal config is set (callers check this beforehand).
    fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
        let nid = global_ctx.get_network_identity();
        // NOTE(review): server and client keys are derived from the same seed
        // (network name + secret), so any network member can reproduce the
        // client private key — presumably intended so members can self-serve
        // the client config; confirm this is acceptable.
        let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
        let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed);
        let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
        let listenr_addr = vpn_cfg.wireguard_listen;
        Self {
            global_ctx,
            peer_mgr,
            wg_config,
            listenr_addr,
            wg_peer_ip_table: Arc::new(DashMap::new()),
            tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
        }
    }

    // Per-client loop: register the client's inner ip from its first packet,
    // then forward every ipv4 packet it sends into the easytier network.
    // Emits connected/disconnected events around the packet loop.
    async fn handle_incoming_conn(
        t: Box<dyn Tunnel>,
        peer_mgr: Arc<PeerManager>,
        wg_peer_ip_table: WgPeerIpTable,
    ) {
        let mut s = t.pin_stream();
        let mut ip_registered = false;
        let info = t.info().unwrap_or_default();
        let remote_addr = info.remote_addr.clone();
        peer_mgr
            .get_global_ctx()
            .issue_event(GlobalCtxEvent::VpnPortalClientConnected(
                info.local_addr,
                info.remote_addr,
            ));
        while let Some(Ok(msg)) = s.next().await {
            let Some(i) = Ipv4Packet::new(&msg) else {
                tracing::error!(?msg, "Failed to parse ipv4 packet");
                continue;
            };
            if !ip_registered {
                // the source addr of the first packet identifies this client;
                // replies from peers are routed back via this table entry
                let client_entry = Arc::new(ClientEntry {
                    endpoint_addr: remote_addr.parse().ok(),
                    sink: Mutex::new(t.pin_sink()),
                });
                wg_peer_ip_table.insert(i.get_source(), client_entry.clone());
                ip_registered = true;
            }
            tracing::trace!(?i, "Received from wg client");
            // forwarding failures are best-effort and intentionally ignored
            let _ = peer_mgr
                .send_msg_ipv4(msg.clone(), i.get_destination())
                .await;
        }
        // stream ended: the client disconnected
        let info = t.info().unwrap_or_default();
        peer_mgr
            .get_global_ctx()
            .issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
                info.local_addr,
                info.remote_addr,
            ));
    }

    // Install a packet filter on the peer manager that intercepts Data packets
    // addressed to a registered wg client ip and pushes the payload out
    // through that client's sink.
    async fn start_pipeline_processor(&self) {
        struct PeerPacketFilterForVpnPortal {
            wg_peer_ip_table: WgPeerIpTable,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for PeerPacketFilterForVpnPortal {
            // Some(()) consumes the packet; None lets the next filter handle it.
            async fn try_process_packet_from_peer(
                &self,
                packet: &ArchivedPacket,
                _: &Bytes,
            ) -> Option<()> {
                if packet.packet_type != packet::PacketType::Data {
                    return None;
                };
                let payload_bytes = packet.payload.as_bytes();
                let ipv4 = Ipv4Packet::new(payload_bytes)?;
                if ipv4.get_version() != 4 {
                    return None;
                }
                let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone();
                tracing::trace!(?ipv4, "Packet filter for vpn portal");
                let ret = entry
                    .sink
                    .lock()
                    .await
                    .send(Bytes::copy_from_slice(payload_bytes))
                    .await;
                ret.ok()
            }
        }
        self.peer_mgr
            .add_packet_process_pipeline(Box::new(PeerPacketFilterForVpnPortal {
                wg_peer_ip_table: self.wg_peer_ip_table.clone(),
            }))
            .await;
    }

    // Bind the wireguard listener and spawn the accept loop. The loop holds
    // only a Weak to the task set, so it stops once the portal is dropped.
    async fn start(&self) -> anyhow::Result<()> {
        let mut l = WgTunnelListener::new(
            format!("wg://{}", self.listenr_addr).parse().unwrap(),
            self.wg_config.clone(),
        );
        l.listen()
            .await
            .with_context(|| "Failed to start wireguard listener for vpn portal")?;
        join_joinset_background(self.tasks.clone(), "wireguard".to_string());
        let tasks = Arc::downgrade(&self.tasks.clone());
        let peer_mgr = self.peer_mgr.clone();
        let wg_peer_ip_table = self.wg_peer_ip_table.clone();
        self.tasks.lock().unwrap().spawn(async move {
            while let Ok(t) = l.accept().await {
                let Some(tasks) = tasks.upgrade() else {
                    break;
                };
                tasks.lock().unwrap().spawn(Self::handle_incoming_conn(
                    t,
                    peer_mgr.clone(),
                    wg_peer_ip_table.clone(),
                ));
            }
        });
        self.start_pipeline_processor().await;
        Ok(())
    }
}
/// Public wireguard vpn portal; the inner impl is created lazily in `start()`.
#[derive(Default)]
pub struct WireGuard {
    inner: Option<WireGuardImpl>,
}
#[async_trait::async_trait]
impl VpnPortal for WireGuard {
    /// Create and start the inner portal. Must be called at most once and
    /// only when a vpn portal config is present.
    async fn start(
        &mut self,
        global_ctx: ArcGlobalCtx,
        peer_mgr: Arc<PeerManager>,
    ) -> anyhow::Result<()> {
        assert!(self.inner.is_none());
        let vpn_cfg = global_ctx.config.get_vpn_portal_config();
        if vpn_cfg.is_none() {
            anyhow::bail!("vpn cfg is not set for wireguard vpn portal");
        }
        let inner = WireGuardImpl::new(global_ctx, peer_mgr);
        inner.start().await?;
        self.inner = Some(inner);
        Ok(())
    }

    /// Render a wg-quick style client config. Panics if called before start().
    async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String {
        let global_ctx = self.inner.as_ref().unwrap().global_ctx.clone();
        let routes = peer_mgr.list_routes().await;
        // AllowedIPs = every proxied cidr of every known route ...
        let mut allow_ips = routes
            .iter()
            .map(|x| x.proxy_cidrs.iter().map(String::to_string))
            .flatten()
            .collect::<Vec<_>>();
        // ... plus the /24 network of the first route with a parseable ipv4
        // (break: a single entry covers the easytier subnet)
        for ipv4 in routes.iter().map(|x| &x.ipv4_addr) {
            let Ok(ipv4) = ipv4.parse() else {
                continue;
            };
            let inet = Ipv4Inet::new(ipv4, 24).unwrap();
            allow_ips.push(inet.network().to_string());
            break;
        }
        let allow_ips = allow_ips
            .into_iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
        let client_cidr = vpn_cfg.client_cidr;
        let cfg = self.inner.as_ref().unwrap().wg_config.clone();
        let cfg_str = format!(
            r#"
[Interface]
PrivateKey = {peer_secret_key}
Address = {client_cidr} # should assign an ip from this cidr manually
[Peer]
PublicKey = {my_public_key}
AllowedIPs = {allow_ips}
Endpoint = {listenr_addr} # should be the public ip of the vpn server
"#,
            peer_secret_key = BASE64_STANDARD.encode(cfg.peer_secret_key()),
            my_public_key = BASE64_STANDARD.encode(cfg.my_public_key()),
            listenr_addr = self.inner.as_ref().unwrap().listenr_addr,
            allow_ips = allow_ips,
            client_cidr = client_cidr,
        );
        cfg_str
    }

    fn name(&self) -> String {
        "wireguard".to_string()
    }

    /// Endpoint addresses of currently connected wg clients (empty string for
    /// clients whose endpoint could not be parsed as a url).
    async fn list_clients(&self) -> Vec<String> {
        self.inner
            .as_ref()
            .unwrap()
            .wg_peer_ip_table
            .iter()
            .map(|x| {
                x.value()
                    .endpoint_addr
                    .as_ref()
                    .map(|x| x.to_string())
                    .unwrap_or_default()
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use super::*;
    use crate::{
        common::{
            config::{NetworkIdentity, VpnPortalConfig},
            global_ctx::tests::get_mock_global_ctx_with_network,
        },
        connector::udp_hole_punch::tests::replace_stun_info_collector,
        peers::{
            peer_manager::{PeerManager, RouteAlgoType},
            tests::wait_for_condition,
        },
        rpc::NatType,
        tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector},
    };

    // Integration-style check: bring up a peer manager, connect it to a node
    // on 127.0.0.1:11010, wait for a route, then start the wireguard portal
    // on 0.0.0.0:11021.
    // NOTE(review): no #[tokio::test] (or similar) attribute is visible on
    // this async fn, so the test harness may never execute it — confirm.
    async fn portal_test() {
        let (s, _r) = tokio::sync::mpsc::channel(1000);
        let peer_mgr = Arc::new(PeerManager::new(
            RouteAlgoType::Ospf,
            get_mock_global_ctx_with_network(Some(NetworkIdentity {
                network_name: "sijie".to_string(),
                network_secret: "1919119".to_string(),
            })),
            s,
        ));
        replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
        peer_mgr
            .get_global_ctx()
            .config
            .set_vpn_portal_config(VpnPortalConfig {
                wireguard_listen: "0.0.0.0:11021".parse().unwrap(),
                client_cidr: "10.14.14.0/24".parse().unwrap(),
            });
        peer_mgr.run().await.unwrap();
        let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap());
        let tunnel = pmgr_conn.connect().await;
        peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap();
        wait_for_condition(
            || async {
                let routes = peer_mgr.list_routes().await;
                println!("Routes: {:?}", routes);
                routes.len() != 0
            },
            std::time::Duration::from_secs(10),
        )
        .await;
        let mut wg = WireGuard::default();
        wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone())
            .await
            .unwrap();
    }
}