@@ -129,18 +129,24 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
129
129
130
130
impl < ' a , S : Deref > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , S > where S :: Target : ScoreLookUp {
131
131
type ScoreParams = <S :: Target as ScoreLookUp >:: ScoreParams ;
132
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
132
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
133
+ let target = candidate. target ( ) ;
134
+ let short_channel_id = match candidate. short_channel_id ( ) {
135
+ Some ( short_channel_id) => short_channel_id,
136
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
137
+ } ;
138
+ let source = candidate. source ( ) ;
133
139
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
134
- source, target, short_channel_id
140
+ & source, & target, short_channel_id
135
141
) {
136
142
let usage = ChannelUsage {
137
143
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
138
144
..usage
139
145
} ;
140
146
141
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
147
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
142
148
} else {
143
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
149
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
144
150
}
145
151
}
146
152
}
@@ -1974,9 +1980,10 @@ where L::Target: Logger {
1974
1980
inflight_htlc_msat: used_liquidity_msat,
1975
1981
effective_capacity,
1976
1982
} ;
1977
- let channel_penalty_msat = scid_opt. map_or( 0 ,
1978
- |scid| scorer. channel_penalty_msat( scid, & $src_node_id, & $dest_node_id,
1979
- channel_usage, score_params) ) ;
1983
+ let channel_penalty_msat =
1984
+ scorer. channel_penalty_msat( $candidate,
1985
+ channel_usage,
1986
+ score_params) ;
1980
1987
let path_penalty_msat = $next_hops_path_penalty_msat
1981
1988
. saturating_add( channel_penalty_msat) ;
1982
1989
let new_graph_node = RouteGraphNode {
@@ -2097,7 +2104,7 @@ where L::Target: Logger {
2097
2104
if let Some ( first_channels) = first_hop_targets. get( & $node_id) {
2098
2105
for details in first_channels {
2099
2106
let candidate = CandidateRouteHop :: FirstHop { details, node_id: our_node_id } ;
2100
- add_entry!( candidate, our_node_id, $node_id, $fee_to_target_msat,
2107
+ add_entry!( & candidate, our_node_id, $node_id, $fee_to_target_msat,
2101
2108
$next_hops_value_contribution,
2102
2109
$next_hops_path_htlc_minimum_msat, $next_hops_path_penalty_msat,
2103
2110
$next_hops_cltv_delta, $next_hops_path_length) ;
@@ -2123,7 +2130,7 @@ where L::Target: Logger {
2123
2130
source_node_id: * source,
2124
2131
target_node_id: $node_id,
2125
2132
} ;
2126
- add_entry!( candidate, * source, $node_id,
2133
+ add_entry!( & candidate, * source, $node_id,
2127
2134
$fee_to_target_msat,
2128
2135
$next_hops_value_contribution,
2129
2136
$next_hops_path_htlc_minimum_msat,
@@ -2154,7 +2161,7 @@ where L::Target: Logger {
2154
2161
payee_node_id_opt. map ( |payee| first_hop_targets. get ( & payee) . map ( |first_channels| {
2155
2162
for details in first_channels {
2156
2163
let candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id } ;
2157
- let added = add_entry ! ( candidate, our_node_id, payee, 0 , path_value_msat,
2164
+ let added = add_entry ! ( & candidate, our_node_id, payee, 0 , path_value_msat,
2158
2165
0 , 0u64 , 0 , 0 ) . is_some ( ) ;
2159
2166
log_trace ! ( logger, "{} direct route to payee via {}" ,
2160
2167
if added { "Added" } else { "Skipped" } , LoggedCandidateHop ( & candidate) ) ;
@@ -2191,7 +2198,7 @@ where L::Target: Logger {
2191
2198
CandidateRouteHop :: OneHopBlinded { hint, hint_idx, target_node_id : maybe_dummy_payee_node_id }
2192
2199
} else { CandidateRouteHop :: Blinded { hint, hint_idx, target_node_id : maybe_dummy_payee_node_id } } ;
2193
2200
let mut path_contribution_msat = path_value_msat;
2194
- if let Some ( hop_used_msat) = add_entry ! ( candidate, intro_node_id, maybe_dummy_payee_node_id,
2201
+ if let Some ( hop_used_msat) = add_entry ! ( & candidate, intro_node_id, maybe_dummy_payee_node_id,
2195
2202
0 , path_contribution_msat, 0 , 0_u64 , 0 , 0 )
2196
2203
{
2197
2204
path_contribution_msat = hop_used_msat;
@@ -2207,7 +2214,7 @@ where L::Target: Logger {
2207
2214
} ;
2208
2215
let path_min = candidate. htlc_minimum_msat ( ) . saturating_add (
2209
2216
compute_fees_saturating ( candidate. htlc_minimum_msat ( ) , candidate. fees ( ) ) ) ;
2210
- add_entry ! ( first_hop_candidate, our_node_id, intro_node_id, blinded_path_fee,
2217
+ add_entry ! ( & first_hop_candidate, our_node_id, intro_node_id, blinded_path_fee,
2211
2218
path_contribution_msat, path_min, 0_u64 , candidate. cltv_expiry_delta( ) ,
2212
2219
candidate. blinded_path( ) . map_or( 1 , |bp| bp. blinded_hops. len( ) as u8 ) ) ;
2213
2220
}
@@ -2261,7 +2268,7 @@ where L::Target: Logger {
2261
2268
} )
2262
2269
. unwrap_or_else ( || CandidateRouteHop :: PrivateHop { hint : hop, target_node_id : target } ) ;
2263
2270
2264
- if let Some ( hop_used_msat) = add_entry ! ( candidate, source, target,
2271
+ if let Some ( hop_used_msat) = add_entry ! ( & candidate, source, target,
2265
2272
aggregate_next_hops_fee_msat, aggregate_path_contribution_msat,
2266
2273
aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat,
2267
2274
aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length)
@@ -2283,7 +2290,7 @@ where L::Target: Logger {
2283
2290
effective_capacity : candidate. effective_capacity ( ) ,
2284
2291
} ;
2285
2292
let channel_penalty_msat = scorer. channel_penalty_msat (
2286
- hop . short_channel_id , & source , & target , channel_usage, score_params
2293
+ & candidate , channel_usage, score_params
2287
2294
) ;
2288
2295
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2289
2296
. saturating_add ( channel_penalty_msat) ;
@@ -2300,7 +2307,7 @@ where L::Target: Logger {
2300
2307
recommended_value_msat, our_node_pubkey) ;
2301
2308
for details in first_channels {
2302
2309
let first_hop_candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id} ;
2303
- add_entry ! ( first_hop_candidate, our_node_id, NodeId :: from_pubkey( & prev_hop_id) ,
2310
+ add_entry ! ( & first_hop_candidate, our_node_id, NodeId :: from_pubkey( & prev_hop_id) ,
2304
2311
aggregate_next_hops_fee_msat, aggregate_path_contribution_msat,
2305
2312
aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat,
2306
2313
aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) ;
@@ -2345,7 +2352,7 @@ where L::Target: Logger {
2345
2352
recommended_value_msat, our_node_pubkey) ;
2346
2353
for details in first_channels {
2347
2354
let first_hop_candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id} ;
2348
- add_entry ! ( first_hop_candidate, our_node_id,
2355
+ add_entry ! ( & first_hop_candidate, our_node_id,
2349
2356
NodeId :: from_pubkey( & hop. src_node_id) ,
2350
2357
aggregate_next_hops_fee_msat,
2351
2358
aggregate_path_contribution_msat,
@@ -2842,13 +2849,15 @@ fn build_route_from_hops_internal<L: Deref>(
2842
2849
2843
2850
impl ScoreLookUp for HopScorer {
2844
2851
type ScoreParams = ( ) ;
2845
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2852
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2846
2853
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2847
2854
{
2855
+ let target = candidate. target ( ) ;
2856
+ let source = candidate. source ( ) ;
2848
2857
let mut cur_id = self . our_node_id ;
2849
2858
for i in 0 ..self . hop_ids . len ( ) {
2850
2859
if let Some ( next_id) = self . hop_ids [ i] {
2851
- if cur_id == * source && next_id == * target {
2860
+ if cur_id == source && next_id == target {
2852
2861
return 0 ;
2853
2862
}
2854
2863
cur_id = next_id;
@@ -2924,6 +2933,8 @@ mod tests {
2924
2933
2925
2934
use core:: convert:: TryInto ;
2926
2935
2936
+ use super :: CandidateRouteHop ;
2937
+
2927
2938
fn get_channel_details ( short_channel_id : Option < u64 > , node_id : PublicKey ,
2928
2939
features : InitFeatures , outbound_capacity_msat : u64 ) -> channelmanager:: ChannelDetails {
2929
2940
channelmanager:: ChannelDetails {
@@ -6177,7 +6188,11 @@ mod tests {
6177
6188
}
6178
6189
impl ScoreLookUp for BadChannelScorer {
6179
6190
type ScoreParams = ( ) ;
6180
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _: & NodeId , _: & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6191
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6192
+ let short_channel_id = match candidate. short_channel_id ( ) {
6193
+ Some ( id) => id,
6194
+ None => return 0 ,
6195
+ } ;
6181
6196
if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6182
6197
}
6183
6198
}
@@ -6193,8 +6208,9 @@ mod tests {
6193
6208
6194
6209
impl ScoreLookUp for BadNodeScorer {
6195
6210
type ScoreParams = ( ) ;
6196
- fn channel_penalty_msat ( & self , _: u64 , _: & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6197
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6211
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6212
+ let target = candidate. target ( ) ;
6213
+ if target == self . node_id { u64:: max_value ( ) } else { 0 }
6198
6214
}
6199
6215
}
6200
6216
@@ -6680,26 +6696,34 @@ mod tests {
6680
6696
} ;
6681
6697
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6682
6698
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6683
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6699
+ let network_graph = network_graph. read_only ( ) ;
6700
+ let channels = network_graph. channels ( ) ;
6701
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6702
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6703
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6704
+ info : info. 0 ,
6705
+ short_channel_id : 5 ,
6706
+ source_node_id : NodeId :: from_pubkey ( & nodes[ 3 ] ) ,
6707
+ target_node_id : NodeId :: from_pubkey ( & nodes[ 4 ] ) ,
6708
+ } ;
6709
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6684
6710
6685
6711
// Then check we can get a normal route
6686
6712
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6687
6713
let route_params = RouteParameters :: from_payment_params_and_value (
6688
6714
payment_params, 100 ) ;
6689
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6715
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6690
6716
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6691
6717
assert ! ( route. is_ok( ) ) ;
6692
6718
6693
6719
// Then check that we can't get a route if we ban an intermediate node.
6694
6720
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6695
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6696
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6721
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6697
6722
assert ! ( route. is_err( ) ) ;
6698
6723
6699
6724
// Finally make sure we can route again, when we remove the ban.
6700
6725
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6701
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6702
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6726
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6703
6727
assert ! ( route. is_ok( ) ) ;
6704
6728
}
6705
6729
0 commit comments