@@ -748,7 +748,7 @@ where L::Target: Logger {
748
748
}
749
749
750
750
let path_penalty_msat = $next_hops_path_penalty_msat
751
- . checked_add( scorer. channel_penalty_msat( $chan_id. clone( ) ) )
751
+ . checked_add( scorer. channel_penalty_msat( $chan_id. clone( ) , & $src_node_id , & $dest_node_id ) )
752
752
. unwrap_or_else( || u64 :: max_value( ) ) ;
753
753
let new_graph_node = RouteGraphNode {
754
754
node_id: $src_node_id,
@@ -973,15 +973,17 @@ where L::Target: Logger {
973
973
_ => aggregate_next_hops_fee_msat. checked_add ( 999 ) . unwrap_or ( u64:: max_value ( ) )
974
974
} ) { Some ( val / 1000 ) } else { break ; } ; // converting from msat or breaking if max ~ infinity
975
975
976
+ let src_node_id = NodeId :: from_pubkey ( & hop. src_node_id ) ;
977
+ let dest_node_id = NodeId :: from_pubkey ( & prev_hop_id) ;
976
978
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
977
- . checked_add ( scorer. channel_penalty_msat ( hop. short_channel_id ) )
979
+ . checked_add ( scorer. channel_penalty_msat ( hop. short_channel_id , & src_node_id , & dest_node_id ) )
978
980
. unwrap_or_else ( || u64:: max_value ( ) ) ;
979
981
980
982
// We assume that the recipient only included route hints for routes which had
981
983
// sufficient value to route `final_value_msat`. Note that in the case of "0-value"
982
984
// invoices where the invoice does not specify value this may not be the case, but
983
985
// better to include the hints than not.
984
- if !add_entry ! ( hop. short_channel_id, NodeId :: from_pubkey ( & hop . src_node_id) , NodeId :: from_pubkey ( & prev_hop_id ) , directional_info, reqd_channel_cap, & empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
986
+ if !add_entry ! ( hop. short_channel_id, src_node_id, dest_node_id , directional_info, reqd_channel_cap, & empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
985
987
// If this hop was not used then there is no use checking the preceding hops
986
988
// in the RouteHint. We can break by just searching for a direct channel between
987
989
// last checked hop and first_hop_targets
@@ -1322,7 +1324,8 @@ where L::Target: Logger {
1322
1324
1323
1325
#[ cfg( test) ]
1324
1326
mod tests {
1325
- use routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler } ;
1327
+ use routing;
1328
+ use routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler , NodeId } ;
1326
1329
use routing:: router:: { get_route, Route , RouteHint , RouteHintHop , RouteHop , RoutingFees } ;
1327
1330
use routing:: scorer:: Scorer ;
1328
1331
use chain:: transaction:: OutPoint ;
@@ -4351,42 +4354,92 @@ mod tests {
4351
4354
let ( secp_ctx, net_graph_msg_handler, _, logger) = build_graph ( ) ;
4352
4355
let ( _, our_id, _, nodes) = get_nodes ( & secp_ctx) ;
4353
4356
4357
+ // Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
4358
+ let scorer = Scorer :: new ( 0 ) ;
4359
+ let route = get_route (
4360
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4361
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4362
+ ) . unwrap ( ) ;
4363
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4364
+
4365
+ assert_eq ! ( route. get_total_fees( ) , 100 ) ;
4366
+ assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4367
+ assert_eq ! ( path, vec![ 2 , 4 , 6 , 11 , 8 ] ) ;
4368
+
4354
4369
// Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
4355
4370
// from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
4356
4371
let scorer = Scorer :: new ( 100 ) ;
4357
- let route = get_route ( & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None , & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer) . unwrap ( ) ;
4358
- assert_eq ! ( route. paths[ 0 ] . len( ) , 4 ) ;
4372
+ let route = get_route (
4373
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4374
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4375
+ ) . unwrap ( ) ;
4376
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4359
4377
4360
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . pubkey, nodes[ 1 ] ) ;
4361
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . short_channel_id, 2 ) ;
4362
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . fee_msat, 200 ) ;
4363
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . cltv_expiry_delta, ( 4 << 8 ) | 1 ) ;
4364
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . node_features. le_flags( ) , & id_to_feature_flags( 2 ) ) ;
4365
- assert_eq ! ( route. paths[ 0 ] [ 0 ] . channel_features. le_flags( ) , & id_to_feature_flags( 2 ) ) ;
4378
+ assert_eq ! ( route. get_total_fees( ) , 300 ) ;
4379
+ assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4380
+ assert_eq ! ( path, vec![ 2 , 4 , 7 , 10 ] ) ;
4381
+ }
4366
4382
4367
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . pubkey, nodes[ 2 ] ) ;
4368
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . short_channel_id, 4 ) ;
4369
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . fee_msat, 100 ) ;
4370
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . cltv_expiry_delta, ( 7 << 8 ) | 1 ) ;
4371
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . node_features. le_flags( ) , & id_to_feature_flags( 3 ) ) ;
4372
- assert_eq ! ( route. paths[ 0 ] [ 1 ] . channel_features. le_flags( ) , & id_to_feature_flags( 4 ) ) ;
4383
+ struct BadChannelScorer {
4384
+ short_channel_id : u64 ,
4385
+ }
4373
4386
4374
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . pubkey, nodes[ 5 ] ) ;
4375
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . short_channel_id, 7 ) ;
4376
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . fee_msat, 0 ) ;
4377
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . cltv_expiry_delta, ( 10 << 8 ) | 1 ) ;
4378
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . node_features. le_flags( ) , & id_to_feature_flags( 6 ) ) ;
4379
- assert_eq ! ( route. paths[ 0 ] [ 2 ] . channel_features. le_flags( ) , & id_to_feature_flags( 7 ) ) ;
4387
+ impl routing:: Score for BadChannelScorer {
4388
+ fn channel_penalty_msat ( & self , short_channel_id : u64 , _source : & NodeId , _target : & NodeId ) -> u64 {
4389
+ if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
4390
+ }
4391
+ }
4380
4392
4381
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . pubkey, nodes[ 6 ] ) ;
4382
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . short_channel_id, 10 ) ;
4383
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . fee_msat, 100 ) ;
4384
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . cltv_expiry_delta, 42 ) ;
4385
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . node_features. le_flags( ) , & Vec :: <u8 >:: new( ) ) ; // We don't pass flags in from invoices yet
4386
- assert_eq ! ( route. paths[ 0 ] [ 3 ] . channel_features. le_flags( ) , & Vec :: <u8 >:: new( ) ) ; // We can't learn any flags from invoices, sadly
4393
+ struct BadNodeScorer {
4394
+ node_id : NodeId ,
4395
+ }
4396
+
4397
+ impl routing:: Score for BadNodeScorer {
4398
+ fn channel_penalty_msat ( & self , _short_channel_id : u64 , _source : & NodeId , target : & NodeId ) -> u64 {
4399
+ if * target == self . node_id { u64:: max_value ( ) } else { 0 }
4400
+ }
4401
+ }
4402
+
4403
+ #[ test]
4404
+ fn avoids_routing_through_bad_channels_and_nodes ( ) {
4405
+ let ( secp_ctx, net_graph_msg_handler, _, logger) = build_graph ( ) ;
4406
+ let ( _, our_id, _, nodes) = get_nodes ( & secp_ctx) ;
4407
+
4408
+ // A path to nodes[6] exists when no penalties are applied to any channel.
4409
+ let scorer = Scorer :: new ( 0 ) ;
4410
+ let route = get_route (
4411
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4412
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4413
+ ) . unwrap ( ) ;
4414
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4415
+
4416
+ assert_eq ! ( route. get_total_fees( ) , 100 ) ;
4417
+ assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4418
+ assert_eq ! ( path, vec![ 2 , 4 , 6 , 11 , 8 ] ) ;
4419
+
4420
+ // A different path to nodes[6] exists if channel 6 cannot be routed over.
4421
+ let scorer = BadChannelScorer { short_channel_id : 6 } ;
4422
+ let route = get_route (
4423
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4424
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4425
+ ) . unwrap ( ) ;
4426
+ let path = route. paths [ 0 ] . iter ( ) . map ( |hop| hop. short_channel_id ) . collect :: < Vec < _ > > ( ) ;
4387
4427
4388
4428
assert_eq ! ( route. get_total_fees( ) , 300 ) ;
4389
4429
assert_eq ! ( route. get_total_amount( ) , 100 ) ;
4430
+ assert_eq ! ( path, vec![ 2 , 4 , 7 , 10 ] ) ;
4431
+
4432
+ // A path to nodes[6] does not exist if nodes[2] cannot be routed through.
4433
+ let scorer = BadNodeScorer { node_id : NodeId :: from_pubkey ( & nodes[ 2 ] ) } ;
4434
+ match get_route (
4435
+ & our_id, & net_graph_msg_handler. network_graph , & nodes[ 6 ] , None , None ,
4436
+ & last_hops ( & nodes) . iter ( ) . collect :: < Vec < _ > > ( ) , 100 , 42 , Arc :: clone ( & logger) , & scorer
4437
+ ) {
4438
+ Err ( LightningError { err, .. } ) => {
4439
+ assert_eq ! ( err, "Failed to find a path to the given destination" ) ;
4440
+ } ,
4441
+ Ok ( _) => panic ! ( "Expected error" ) ,
4442
+ }
4390
4443
}
4391
4444
4392
4445
#[ test]
0 commit comments