support custom cost calculate func when generating route table

This commit is contained in:
sijie.sun 2024-05-12 11:48:13 +08:00
parent 51aa23b635
commit 72f86025bd
2 changed files with 181 additions and 9 deletions

View File

@ -19,7 +19,14 @@ use crate::{
rpc::{NatType, StunInfo},
};
use super::{peer_rpc::PeerRpcManager, PeerPacketFilter};
use super::{
peer_rpc::PeerRpcManager,
route_trait::{
DefaultRouteCostCalculator, NextHopPolicy, RouteCostCalculator,
RouteCostCalculatorInterface,
},
PeerPacketFilter,
};
static SERVICE_ID: u32 = 7;
static UPDATE_PEER_INFO_PERIOD: Duration = Duration::from_secs(3600);
@ -393,7 +400,12 @@ impl RouteTable {
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
}
fn build_from_synced_info(&self, my_peer_id: PeerId, synced_info: &SyncedRouteInfo) {
fn build_from_synced_info<T: RouteCostCalculatorInterface>(
&self,
my_peer_id: PeerId,
synced_info: &SyncedRouteInfo,
cost_calc: T,
) {
// build peer_infos
self.peer_infos.clear();
for item in synced_info.peer_infos.iter() {
@ -415,12 +427,18 @@ impl RouteTable {
if peer_id == my_peer_id {
continue;
}
let Some(path) = pathfinding::prelude::bfs(
let Some((path, _cost)): Option<(Vec<u32>, i32)> = pathfinding::prelude::dijkstra(
&my_peer_id,
|p| {
|src_peer| {
synced_info
.get_connected_peers(*p)
.get_connected_peers(*src_peer)
.unwrap_or_else(|| BTreeSet::new())
.into_iter()
.map(|dst_peer| {
let cost = cost_calc.calculate_cost(*src_peer, dst_peer);
(dst_peer, cost)
})
.collect::<BTreeSet<_>>()
},
|x| *x == peer_id,
) else {
@ -563,7 +581,9 @@ struct PeerRouteServiceImpl {
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
cost_calculator: Arc<std::sync::Mutex<Option<RouteCostCalculator>>>,
route_table: RouteTable,
route_table_with_cost: RouteTable,
synced_route_info: Arc<SyncedRouteInfo>,
cached_local_conn_map: std::sync::Mutex<RouteConnBitmap>,
}
@ -585,9 +605,17 @@ impl PeerRouteServiceImpl {
PeerRouteServiceImpl {
my_peer_id,
global_ctx,
interface: Arc::new(Mutex::new(None)),
sessions: DashMap::new(),
interface: Arc::new(Mutex::new(None)),
cost_calculator: Arc::new(std::sync::Mutex::new(Some(Box::new(
DefaultRouteCostCalculator,
)))),
route_table: RouteTable::new(),
route_table_with_cost: RouteTable::new(),
synced_route_info: Arc::new(SyncedRouteInfo {
peer_infos: DashMap::new(),
conn_map: DashMap::new(),
@ -649,8 +677,31 @@ impl PeerRouteServiceImpl {
}
fn update_route_table(&self) {
self.route_table
.build_from_synced_info(self.my_peer_id, &self.synced_route_info);
self.route_table.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
DefaultRouteCostCalculator::default(),
);
let calc_locked = self.cost_calculator.lock().unwrap();
if calc_locked.is_none() {
return;
}
self.route_table_with_cost.build_from_synced_info(
self.my_peer_id,
&self.synced_route_info,
&calc_locked.as_ref().unwrap(),
);
}
fn cost_calculator_need_update(&self) -> bool {
self.cost_calculator
.lock()
.unwrap()
.as_ref()
.map(|x| x.need_update())
.unwrap_or(false)
}
fn update_route_table_and_cached_local_conn_bitmap(&self) {
@ -1183,6 +1234,10 @@ impl PeerRoute {
session_mgr.sync_now("update_my_infos");
}
if service_impl.cost_calculator_need_update() {
service_impl.update_route_table();
}
select! {
ev = global_event_receiver.recv() => {
tracing::info!(?ev, "global event received in update_my_peer_info_routine");
@ -1234,6 +1289,19 @@ impl Route for PeerRoute {
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
}
async fn get_next_hop_with_policy(
&self,
dst_peer_id: PeerId,
policy: NextHopPolicy,
) -> Option<PeerId> {
let route_table = if matches!(policy, NextHopPolicy::LeastCost) {
&self.service_impl.route_table_with_cost
} else {
&self.service_impl.route_table
};
route_table.get_next_hop(dst_peer_id).map(|x| x.0)
}
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
let route_table = &self.service_impl.route_table;
let mut routes = Vec::new();
@ -1265,6 +1333,11 @@ impl Route for PeerRoute {
tracing::info!(?ipv4_addr, "no peer id for ipv4");
None
}
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {
*self.service_impl.cost_calculator.lock().unwrap() = Some(_cost_fn);
self.service_impl.update_route_table();
}
}
impl PeerPacketFilter for Arc<PeerRoute> {}
@ -1282,7 +1355,7 @@ mod tests {
connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{
peer_manager::{PeerManager, RouteAlgoType},
route_trait::Route,
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
tests::{connect_peer_manager, wait_for_condition},
},
rpc::NatType,
@ -1609,4 +1682,59 @@ mod tests {
println!("session: {:?}", r_a.session_mgr.dump_sessions());
check_rpc_counter(&r_a, p_b.my_peer_id(), 2, 2);
}
#[tokio::test]
async fn test_cost_calculator() {
let p_a = create_mock_pmgr().await;
let p_b = create_mock_pmgr().await;
let p_c = create_mock_pmgr().await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_c.clone(), p_b.clone()).await;
connect_peer_manager(p_a.clone(), p_c.clone()).await;
let _r_a = create_mock_route(p_a.clone()).await;
let _r_b = create_mock_route(p_b.clone()).await;
let r_c = create_mock_route(p_c.clone()).await;
// in normal mode, packet from p_c should directly forward to p_a
wait_for_condition(
|| async { r_c.get_next_hop(p_a.my_peer_id()).await == Some(p_a.my_peer_id()) },
Duration::from_secs(5),
)
.await;
struct TestCostCalculator {
p_a_peer_id: PeerId,
p_b_peer_id: PeerId,
p_c_peer_id: PeerId,
}
impl RouteCostCalculatorInterface for TestCostCalculator {
fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 {
if src == self.p_c_peer_id && dst == self.p_a_peer_id {
return 100;
}
1
}
}
r_c.set_route_cost_fn(Box::new(TestCostCalculator {
p_a_peer_id: p_a.my_peer_id(),
p_b_peer_id: p_b.my_peer_id(),
p_c_peer_id: p_c.my_peer_id(),
}))
.await;
// after set cost, packet from p_c should forward to p_b first
wait_for_condition(
|| async {
r_c.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost)
.await
== Some(p_b.my_peer_id())
},
Duration::from_secs(5),
)
.await;
}
}

View File

@ -5,6 +5,18 @@ use tokio_util::bytes::Bytes;
use crate::common::{error::Error, PeerId};
#[derive(Clone, Debug)]
pub enum NextHopPolicy {
LeastHop,
LeastCost,
}
impl Default for NextHopPolicy {
fn default() -> Self {
NextHopPolicy::LeastHop
}
}
#[async_trait]
pub trait RouteInterface {
async fn list_peers(&self) -> Vec<PeerId>;
@ -19,6 +31,28 @@ pub trait RouteInterface {
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
#[auto_impl::auto_impl(Box, Arc, &)]
pub trait RouteCostCalculatorInterface: Send + Sync {
fn calculate_cost(&self, _src: PeerId, _dst: PeerId) -> i32 {
1
}
fn need_update(&self) -> bool {
false
}
fn dump(&self) -> String {
"All routes have cost 1".to_string()
}
}
#[derive(Clone, Debug, Default)]
pub struct DefaultRouteCostCalculator;
impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {}
pub type RouteCostCalculator = Box<dyn RouteCostCalculatorInterface>;
#[async_trait]
#[auto_impl::auto_impl(Box, Arc)]
pub trait Route {
@ -26,11 +60,21 @@ pub trait Route {
async fn close(&self);
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
async fn get_next_hop_with_policy(
&self,
peer_id: PeerId,
_policy: NextHopPolicy,
) -> Option<PeerId> {
self.get_next_hop(peer_id).await
}
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
None
}
async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {}
}
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;