@@ -748,7 +748,7 @@ where L::Target: Logger {
748
748
}
749
749
750
750
let path_penalty_msat = $next_hops_path_penalty_msat
751
- . checked_add( scorer. channel_penalty_msat( $chan_id. clone( ) ) )
751
+ . checked_add( scorer. channel_penalty_msat( $chan_id. clone( ) , & $src_node_id , & $dest_node_id ) )
752
752
. unwrap_or_else( || u64 :: max_value( ) ) ;
753
753
let new_graph_node = RouteGraphNode {
754
754
node_id: $src_node_id,
@@ -973,15 +973,17 @@ where L::Target: Logger {
973
973
_ => aggregate_next_hops_fee_msat. checked_add ( 999 ) . unwrap_or ( u64:: max_value ( ) )
974
974
} ) { Some ( val / 1000 ) } else { break ; } ; // converting from msat or breaking if max ~ infinity
975
975
976
+ let src_node_id = NodeId :: from_pubkey ( & hop. src_node_id ) ;
977
+ let dest_node_id = NodeId :: from_pubkey ( & prev_hop_id) ;
976
978
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
977
- . checked_add ( scorer. channel_penalty_msat ( hop. short_channel_id ) )
979
+ . checked_add ( scorer. channel_penalty_msat ( hop. short_channel_id , & src_node_id , & dest_node_id ) )
978
980
. unwrap_or_else ( || u64:: max_value ( ) ) ;
979
981
980
982
// We assume that the recipient only included route hints for routes which had
981
983
// sufficient value to route `final_value_msat`. Note that in the case of "0-value"
982
984
// invoices where the invoice does not specify value this may not be the case, but
983
985
// better to include the hints than not.
984
- if !add_entry ! ( hop. short_channel_id, NodeId :: from_pubkey ( & hop . src_node_id) , NodeId :: from_pubkey ( & prev_hop_id ) , directional_info, reqd_channel_cap, & empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
986
+ if !add_entry ! ( hop. short_channel_id, src_node_id, dest_node_id , directional_info, reqd_channel_cap, & empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
985
987
// If this hop was not used then there is no use checking the preceding hops
986
988
// in the RouteHint. We can break by just searching for a direct channel between
987
989
// last checked hop and first_hop_targets
@@ -1322,7 +1324,8 @@ where L::Target: Logger {
1322
1324
1323
1325
#[ cfg( test) ]
1324
1326
mod tests {
1325
- use routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler } ;
1327
+ use routing;
1328
+ use routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler , NodeId } ;
1326
1329
use routing:: router:: { get_route, Route , RouteHint , RouteHintHop , RouteHop , RoutingFees } ;
1327
1330
use routing:: scorer:: Scorer ;
1328
1331
use chain:: transaction:: OutPoint ;
@@ -4377,6 +4380,68 @@ mod tests {
4377
4380
assert_eq ! ( path, vec![ 2 , 4 , 7 , 10 ] ) ;
4378
4381
}
4379
4382
4383
+ struct BadChannelScorer {
4384
+ short_channel_id : u64 ,
4385
+ }
4386
+
4387
+ impl routing:: Score for BadChannelScorer {
4388
+ fn channel_penalty_msat ( & self , short_channel_id : u64 , _source : & NodeId , _target : & NodeId ) -> u64 {
4389
+ if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
4390
+ }
4391
+ }
4392
+
4393
+ struct BadNodeScorer {
4394
+ node_id : NodeId ,
4395
+ }
4396
+
4397
+ impl routing:: Score for BadNodeScorer {
4398
+ fn channel_penalty_msat ( & self , _short_channel_id : u64 , _source : & NodeId , target : & NodeId ) -> u64 {
4399
+ if * target == self . node_id { u64:: max_value ( ) } else { 0 }
4400
+ }
4401
+ }
4402
+
4403
+ #[ test]
4404
+ fn avoids_routing_through_bad_channels_and_nodes ( ) {
4405
+ let ( secp_ctx, net_graph_msg_handler, _, logger) = build_graph ( ) ;
4406
+ let ( _, our_id, _, nodes) = get_nodes ( & secp_ctx) ;
4407
+
4408
+ // A path to nodes[6] exists when no penalties are applied to any channel.
4409
+ let scorer = Scorer :: new ( 0 ) ;
4410
+ let route = get_route (
4411
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4412
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4413
+ ) . unwrap ( ) ;
4414
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4415
+
4416
+ assert_eq ! ( route. get_total_fees( ) , 100 ) ;
4417
+ assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4418
+ assert_eq ! ( path, vec![ 2 , 4 , 6 , 11 , 8 ] ) ;
4419
+
4420
+ // A different path to nodes[6] exists if channel 6 cannot be routed over.
4421
+ let scorer = BadChannelScorer { short_channel_id : 6 } ;
4422
+ let route = get_route (
4423
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4424
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4425
+ ) . unwrap ( ) ;
4426
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4427
+
4428
+ assert_eq ! ( route. get_total_fees( ) , 300 ) ;
4429
+ assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4430
+ assert_eq ! ( path, vec![ 2 , 4 , 7 , 10 ] ) ;
4431
+
4432
+ // A path to nodes[6] does not exist if nodes[2] cannot be routed through.
4433
+ let scorer = BadNodeScorer { node_id : NodeId :: from_pubkey ( & nodes[ 2 ] ) } ;
4434
+ match get_route (
4435
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4436
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4437
+ ) {
4438
+ Err ( LightningError { err, .. } ) => {
4439
+ assert_eq ! ( err, "Failed to find a path to the given destination" ) ;
4440
+ } ,
4441
+ Ok ( _) => panic ! ( "Expected error" ) ,
4442
+ }
4443
+ }
4444
+
4380
4445
#[ test]
4381
4446
fn total_fees_single_path ( ) {
4382
4447
let route = Route {
0 commit comments