mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-16 11:42:27 +08:00
support custom cost calculate func when generating route table
This commit is contained in:
parent
51aa23b635
commit
72f86025bd
|
@ -19,7 +19,14 @@ use crate::{
|
|||
rpc::{NatType, StunInfo},
|
||||
};
|
||||
|
||||
use super::{peer_rpc::PeerRpcManager, PeerPacketFilter};
|
||||
use super::{
|
||||
peer_rpc::PeerRpcManager,
|
||||
route_trait::{
|
||||
DefaultRouteCostCalculator, NextHopPolicy, RouteCostCalculator,
|
||||
RouteCostCalculatorInterface,
|
||||
},
|
||||
PeerPacketFilter,
|
||||
};
|
||||
|
||||
static SERVICE_ID: u32 = 7;
|
||||
static UPDATE_PEER_INFO_PERIOD: Duration = Duration::from_secs(3600);
|
||||
|
@ -393,7 +400,12 @@ impl RouteTable {
|
|||
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
|
||||
}
|
||||
|
||||
fn build_from_synced_info(&self, my_peer_id: PeerId, synced_info: &SyncedRouteInfo) {
|
||||
fn build_from_synced_info<T: RouteCostCalculatorInterface>(
|
||||
&self,
|
||||
my_peer_id: PeerId,
|
||||
synced_info: &SyncedRouteInfo,
|
||||
cost_calc: T,
|
||||
) {
|
||||
// build peer_infos
|
||||
self.peer_infos.clear();
|
||||
for item in synced_info.peer_infos.iter() {
|
||||
|
@ -415,12 +427,18 @@ impl RouteTable {
|
|||
if peer_id == my_peer_id {
|
||||
continue;
|
||||
}
|
||||
let Some(path) = pathfinding::prelude::bfs(
|
||||
let Some((path, _cost)): Option<(Vec<u32>, i32)> = pathfinding::prelude::dijkstra(
|
||||
&my_peer_id,
|
||||
|p| {
|
||||
|src_peer| {
|
||||
synced_info
|
||||
.get_connected_peers(*p)
|
||||
.get_connected_peers(*src_peer)
|
||||
.unwrap_or_else(|| BTreeSet::new())
|
||||
.into_iter()
|
||||
.map(|dst_peer| {
|
||||
let cost = cost_calc.calculate_cost(*src_peer, dst_peer);
|
||||
(dst_peer, cost)
|
||||
})
|
||||
.collect::<BTreeSet<_>>()
|
||||
},
|
||||
|x| *x == peer_id,
|
||||
) else {
|
||||
|
@ -563,7 +581,9 @@ struct PeerRouteServiceImpl {
|
|||
|
||||
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
|
||||
|
||||
cost_calculator: Arc<std::sync::Mutex<Option<RouteCostCalculator>>>,
|
||||
route_table: RouteTable,
|
||||
route_table_with_cost: RouteTable,
|
||||
synced_route_info: Arc<SyncedRouteInfo>,
|
||||
cached_local_conn_map: std::sync::Mutex<RouteConnBitmap>,
|
||||
}
|
||||
|
@ -585,9 +605,17 @@ impl PeerRouteServiceImpl {
|
|||
PeerRouteServiceImpl {
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
interface: Arc::new(Mutex::new(None)),
|
||||
sessions: DashMap::new(),
|
||||
|
||||
interface: Arc::new(Mutex::new(None)),
|
||||
|
||||
cost_calculator: Arc::new(std::sync::Mutex::new(Some(Box::new(
|
||||
DefaultRouteCostCalculator,
|
||||
)))),
|
||||
|
||||
route_table: RouteTable::new(),
|
||||
route_table_with_cost: RouteTable::new(),
|
||||
|
||||
synced_route_info: Arc::new(SyncedRouteInfo {
|
||||
peer_infos: DashMap::new(),
|
||||
conn_map: DashMap::new(),
|
||||
|
@ -649,8 +677,31 @@ impl PeerRouteServiceImpl {
|
|||
}
|
||||
|
||||
fn update_route_table(&self) {
|
||||
self.route_table
|
||||
.build_from_synced_info(self.my_peer_id, &self.synced_route_info);
|
||||
self.route_table.build_from_synced_info(
|
||||
self.my_peer_id,
|
||||
&self.synced_route_info,
|
||||
DefaultRouteCostCalculator::default(),
|
||||
);
|
||||
|
||||
let calc_locked = self.cost_calculator.lock().unwrap();
|
||||
if calc_locked.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.route_table_with_cost.build_from_synced_info(
|
||||
self.my_peer_id,
|
||||
&self.synced_route_info,
|
||||
&calc_locked.as_ref().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
fn cost_calculator_need_update(&self) -> bool {
|
||||
self.cost_calculator
|
||||
.lock()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.map(|x| x.need_update())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn update_route_table_and_cached_local_conn_bitmap(&self) {
|
||||
|
@ -1183,6 +1234,10 @@ impl PeerRoute {
|
|||
session_mgr.sync_now("update_my_infos");
|
||||
}
|
||||
|
||||
if service_impl.cost_calculator_need_update() {
|
||||
service_impl.update_route_table();
|
||||
}
|
||||
|
||||
select! {
|
||||
ev = global_event_receiver.recv() => {
|
||||
tracing::info!(?ev, "global event received in update_my_peer_info_routine");
|
||||
|
@ -1234,6 +1289,19 @@ impl Route for PeerRoute {
|
|||
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
|
||||
}
|
||||
|
||||
async fn get_next_hop_with_policy(
|
||||
&self,
|
||||
dst_peer_id: PeerId,
|
||||
policy: NextHopPolicy,
|
||||
) -> Option<PeerId> {
|
||||
let route_table = if matches!(policy, NextHopPolicy::LeastCost) {
|
||||
&self.service_impl.route_table_with_cost
|
||||
} else {
|
||||
&self.service_impl.route_table
|
||||
};
|
||||
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
|
||||
let route_table = &self.service_impl.route_table;
|
||||
let mut routes = Vec::new();
|
||||
|
@ -1265,6 +1333,11 @@ impl Route for PeerRoute {
|
|||
tracing::info!(?ipv4_addr, "no peer id for ipv4");
|
||||
None
|
||||
}
|
||||
|
||||
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {
|
||||
*self.service_impl.cost_calculator.lock().unwrap() = Some(_cost_fn);
|
||||
self.service_impl.update_route_table();
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerPacketFilter for Arc<PeerRoute> {}
|
||||
|
@ -1282,7 +1355,7 @@ mod tests {
|
|||
connector::udp_hole_punch::tests::replace_stun_info_collector,
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
route_trait::Route,
|
||||
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
|
||||
tests::{connect_peer_manager, wait_for_condition},
|
||||
},
|
||||
rpc::NatType,
|
||||
|
@ -1609,4 +1682,59 @@ mod tests {
|
|||
println!("session: {:?}", r_a.session_mgr.dump_sessions());
|
||||
check_rpc_counter(&r_a, p_b.my_peer_id(), 2, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cost_calculator() {
|
||||
let p_a = create_mock_pmgr().await;
|
||||
let p_b = create_mock_pmgr().await;
|
||||
let p_c = create_mock_pmgr().await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_c.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_a.clone(), p_c.clone()).await;
|
||||
|
||||
let _r_a = create_mock_route(p_a.clone()).await;
|
||||
let _r_b = create_mock_route(p_b.clone()).await;
|
||||
let r_c = create_mock_route(p_c.clone()).await;
|
||||
|
||||
// in normal mode, packet from p_c should directly forward to p_a
|
||||
wait_for_condition(
|
||||
|| async { r_c.get_next_hop(p_a.my_peer_id()).await == Some(p_a.my_peer_id()) },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
struct TestCostCalculator {
|
||||
p_a_peer_id: PeerId,
|
||||
p_b_peer_id: PeerId,
|
||||
p_c_peer_id: PeerId,
|
||||
}
|
||||
|
||||
impl RouteCostCalculatorInterface for TestCostCalculator {
|
||||
fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 {
|
||||
if src == self.p_c_peer_id && dst == self.p_a_peer_id {
|
||||
return 100;
|
||||
}
|
||||
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
r_c.set_route_cost_fn(Box::new(TestCostCalculator {
|
||||
p_a_peer_id: p_a.my_peer_id(),
|
||||
p_b_peer_id: p_b.my_peer_id(),
|
||||
p_c_peer_id: p_c.my_peer_id(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// after set cost, packet from p_c should forward to p_b first
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
r_c.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost)
|
||||
.await
|
||||
== Some(p_b.my_peer_id())
|
||||
},
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,18 @@ use tokio_util::bytes::Bytes;
|
|||
|
||||
use crate::common::{error::Error, PeerId};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum NextHopPolicy {
|
||||
LeastHop,
|
||||
LeastCost,
|
||||
}
|
||||
|
||||
impl Default for NextHopPolicy {
|
||||
fn default() -> Self {
|
||||
NextHopPolicy::LeastHop
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait RouteInterface {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
|
@ -19,6 +31,28 @@ pub trait RouteInterface {
|
|||
|
||||
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
|
||||
|
||||
#[auto_impl::auto_impl(Box, Arc, &)]
|
||||
pub trait RouteCostCalculatorInterface: Send + Sync {
|
||||
fn calculate_cost(&self, _src: PeerId, _dst: PeerId) -> i32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn need_update(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn dump(&self) -> String {
|
||||
"All routes have cost 1".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DefaultRouteCostCalculator;
|
||||
|
||||
impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {}
|
||||
|
||||
pub type RouteCostCalculator = Box<dyn RouteCostCalculatorInterface>;
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Route {
|
||||
|
@ -26,11 +60,21 @@ pub trait Route {
|
|||
async fn close(&self);
|
||||
|
||||
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
|
||||
async fn get_next_hop_with_policy(
|
||||
&self,
|
||||
peer_id: PeerId,
|
||||
_policy: NextHopPolicy,
|
||||
) -> Option<PeerId> {
|
||||
self.get_next_hop(peer_id).await
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {}
|
||||
}
|
||||
|
||||
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
|
||||
|
|
Loading…
Reference in New Issue
Block a user