@@ -134,18 +134,27 @@ impl<'a, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>, S: Deref<Target = Sc>> Wr
134
134
135
135
impl < ' a , SP : Sized , Sc : ' a + ScoreLookUp < ScoreParams = SP > , S : Deref < Target = Sc > > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , SP , Sc , S > {
136
136
type ScoreParams = Sc :: ScoreParams ;
137
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
137
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
138
+ let target = match candidate. target ( ) {
139
+ Some ( target) => target,
140
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
141
+ } ;
142
+ let short_channel_id = match candidate. short_channel_id ( ) {
143
+ Some ( short_channel_id) => short_channel_id,
144
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
145
+ } ;
146
+ let source = candidate. source ( ) ;
138
147
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
139
- source, target, short_channel_id
148
+ & source, & target, short_channel_id
140
149
) {
141
150
let usage = ChannelUsage {
142
151
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
143
152
..usage
144
153
} ;
145
154
146
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
155
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
147
156
} else {
148
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
157
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
149
158
}
150
159
}
151
160
}
@@ -1961,9 +1970,10 @@ where L::Target: Logger {
1961
1970
inflight_htlc_msat: used_liquidity_msat,
1962
1971
effective_capacity,
1963
1972
} ;
1964
- let channel_penalty_msat = scid_opt. map_or( 0 ,
1965
- |scid| scorer. channel_penalty_msat( scid, & $src_node_id, & $dest_node_id,
1966
- channel_usage, score_params) ) ;
1973
+ let channel_penalty_msat =
1974
+ scorer. channel_penalty_msat( $candidate,
1975
+ channel_usage,
1976
+ score_params) ;
1967
1977
let path_penalty_msat = $next_hops_path_penalty_msat
1968
1978
. saturating_add( channel_penalty_msat) ;
1969
1979
let new_graph_node = RouteGraphNode {
@@ -2084,7 +2094,7 @@ where L::Target: Logger {
2084
2094
if let Some ( first_channels) = first_hop_targets. get( & $node_id) {
2085
2095
for details in first_channels {
2086
2096
let candidate = CandidateRouteHop :: FirstHop { details, node_id: our_node_id } ;
2087
- add_entry!( candidate, our_node_id, $node_id, $fee_to_target_msat,
2097
+ add_entry!( & candidate, our_node_id, $node_id, $fee_to_target_msat,
2088
2098
$next_hops_value_contribution,
2089
2099
$next_hops_path_htlc_minimum_msat, $next_hops_path_penalty_msat,
2090
2100
$next_hops_cltv_delta, $next_hops_path_length) ;
@@ -2110,7 +2120,7 @@ where L::Target: Logger {
2110
2120
source_node_id: * source,
2111
2121
target_node_id: $node_id,
2112
2122
} ;
2113
- add_entry!( candidate, * source, $node_id,
2123
+ add_entry!( & candidate, * source, $node_id,
2114
2124
$fee_to_target_msat,
2115
2125
$next_hops_value_contribution,
2116
2126
$next_hops_path_htlc_minimum_msat,
@@ -2141,7 +2151,7 @@ where L::Target: Logger {
2141
2151
payee_node_id_opt. map ( |payee| first_hop_targets. get ( & payee) . map ( |first_channels| {
2142
2152
for details in first_channels {
2143
2153
let candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id } ;
2144
- let added = add_entry ! ( candidate, our_node_id, payee, 0 , path_value_msat,
2154
+ let added = add_entry ! ( & candidate, our_node_id, payee, 0 , path_value_msat,
2145
2155
0 , 0u64 , 0 , 0 ) . is_some ( ) ;
2146
2156
log_trace ! ( logger, "{} direct route to payee via {}" ,
2147
2157
if added { "Added" } else { "Skipped" } , LoggedCandidateHop ( & candidate) ) ;
@@ -2178,7 +2188,7 @@ where L::Target: Logger {
2178
2188
CandidateRouteHop :: OneHopBlinded { hint, hint_idx }
2179
2189
} else { CandidateRouteHop :: Blinded { hint, hint_idx } } ;
2180
2190
let mut path_contribution_msat = path_value_msat;
2181
- if let Some ( hop_used_msat) = add_entry ! ( candidate, intro_node_id, maybe_dummy_payee_node_id,
2191
+ if let Some ( hop_used_msat) = add_entry ! ( & candidate, intro_node_id, maybe_dummy_payee_node_id,
2182
2192
0 , path_contribution_msat, 0 , 0_u64 , 0 , 0 )
2183
2193
{
2184
2194
path_contribution_msat = hop_used_msat;
@@ -2194,7 +2204,7 @@ where L::Target: Logger {
2194
2204
} ;
2195
2205
let path_min = candidate. htlc_minimum_msat ( ) . saturating_add (
2196
2206
compute_fees_saturating ( candidate. htlc_minimum_msat ( ) , candidate. fees ( ) ) ) ;
2197
- add_entry ! ( first_hop_candidate, our_node_id, intro_node_id, blinded_path_fee,
2207
+ add_entry ! ( & first_hop_candidate, our_node_id, intro_node_id, blinded_path_fee,
2198
2208
path_contribution_msat, path_min, 0_u64 , candidate. cltv_expiry_delta( ) ,
2199
2209
candidate. blinded_path( ) . map_or( 1 , |bp| bp. blinded_hops. len( ) as u8 ) ) ;
2200
2210
}
@@ -2248,7 +2258,7 @@ where L::Target: Logger {
2248
2258
} )
2249
2259
. unwrap_or_else ( || CandidateRouteHop :: PrivateHop { hint : hop, target_node_id : target } ) ;
2250
2260
2251
- if let Some ( hop_used_msat) = add_entry ! ( candidate, source, target,
2261
+ if let Some ( hop_used_msat) = add_entry ! ( & candidate, source, target,
2252
2262
aggregate_next_hops_fee_msat, aggregate_path_contribution_msat,
2253
2263
aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat,
2254
2264
aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length)
@@ -2270,7 +2280,7 @@ where L::Target: Logger {
2270
2280
effective_capacity : candidate. effective_capacity ( ) ,
2271
2281
} ;
2272
2282
let channel_penalty_msat = scorer. channel_penalty_msat (
2273
- hop . short_channel_id , & source , & target , channel_usage, score_params
2283
+ & candidate , channel_usage, score_params
2274
2284
) ;
2275
2285
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2276
2286
. saturating_add ( channel_penalty_msat) ;
@@ -2287,7 +2297,7 @@ where L::Target: Logger {
2287
2297
recommended_value_msat, our_node_pubkey) ;
2288
2298
for details in first_channels {
2289
2299
let first_hop_candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id} ;
2290
- add_entry ! ( first_hop_candidate, our_node_id, NodeId :: from_pubkey( & prev_hop_id) ,
2300
+ add_entry ! ( & first_hop_candidate, our_node_id, NodeId :: from_pubkey( & prev_hop_id) ,
2291
2301
aggregate_next_hops_fee_msat, aggregate_path_contribution_msat,
2292
2302
aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat,
2293
2303
aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) ;
@@ -2332,7 +2342,7 @@ where L::Target: Logger {
2332
2342
recommended_value_msat, our_node_pubkey) ;
2333
2343
for details in first_channels {
2334
2344
let first_hop_candidate = CandidateRouteHop :: FirstHop { details, node_id : our_node_id} ;
2335
- add_entry ! ( first_hop_candidate, our_node_id,
2345
+ add_entry ! ( & first_hop_candidate, our_node_id,
2336
2346
NodeId :: from_pubkey( & hop. src_node_id) ,
2337
2347
aggregate_next_hops_fee_msat,
2338
2348
aggregate_path_contribution_msat,
@@ -2829,13 +2839,18 @@ fn build_route_from_hops_internal<L: Deref>(
2829
2839
2830
2840
impl ScoreLookUp for HopScorer {
2831
2841
type ScoreParams = ( ) ;
2832
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2842
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2833
2843
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2834
2844
{
2845
+ let target = match candidate. target ( ) {
2846
+ Some ( target) => target,
2847
+ None => return 0 ,
2848
+ } ;
2849
+ let source = candidate. source ( ) ;
2835
2850
let mut cur_id = self . our_node_id ;
2836
2851
for i in 0 ..self . hop_ids . len ( ) {
2837
2852
if let Some ( next_id) = self . hop_ids [ i] {
2838
- if cur_id == * source && next_id == * target {
2853
+ if cur_id == source && next_id == target {
2839
2854
return 0 ;
2840
2855
}
2841
2856
cur_id = next_id;
@@ -2911,6 +2926,8 @@ mod tests {
2911
2926
2912
2927
use core:: convert:: TryInto ;
2913
2928
2929
+ use super :: CandidateRouteHop ;
2930
+
2914
2931
fn get_channel_details ( short_channel_id : Option < u64 > , node_id : PublicKey ,
2915
2932
features : InitFeatures , outbound_capacity_msat : u64 ) -> channelmanager:: ChannelDetails {
2916
2933
channelmanager:: ChannelDetails {
@@ -6164,7 +6181,11 @@ mod tests {
6164
6181
}
6165
6182
impl ScoreLookUp for BadChannelScorer {
6166
6183
type ScoreParams = ( ) ;
6167
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _: & NodeId , _: & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6184
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6185
+ let short_channel_id = match candidate. short_channel_id ( ) {
6186
+ Some ( id) => id,
6187
+ None => return 0 ,
6188
+ } ;
6168
6189
if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6169
6190
}
6170
6191
}
@@ -6180,8 +6201,12 @@ mod tests {
6180
6201
6181
6202
impl ScoreLookUp for BadNodeScorer {
6182
6203
type ScoreParams = ( ) ;
6183
- fn channel_penalty_msat ( & self , _: u64 , _: & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6184
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6204
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6205
+ let target = match candidate. target ( ) {
6206
+ Some ( target) => target,
6207
+ None => return 0 ,
6208
+ } ;
6209
+ if target == self . node_id { u64:: max_value ( ) } else { 0 }
6185
6210
}
6186
6211
}
6187
6212
@@ -6667,26 +6692,34 @@ mod tests {
6667
6692
} ;
6668
6693
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6669
6694
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6670
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6695
+ let network_graph = network_graph. read_only ( ) ;
6696
+ let channels = network_graph. channels ( ) ;
6697
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6698
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6699
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6700
+ info : info. 0 ,
6701
+ short_channel_id : 5 ,
6702
+ source_node_id : NodeId :: from_pubkey ( & nodes[ 3 ] ) ,
6703
+ target_node_id : NodeId :: from_pubkey ( & nodes[ 4 ] ) ,
6704
+ } ;
6705
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6671
6706
6672
6707
// Then check we can get a normal route
6673
6708
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6674
6709
let route_params = RouteParameters :: from_payment_params_and_value (
6675
6710
payment_params, 100 ) ;
6676
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6711
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6677
6712
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6678
6713
assert ! ( route. is_ok( ) ) ;
6679
6714
6680
6715
// Then check that we can't get a route if we ban an intermediate node.
6681
6716
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6682
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6683
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6717
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6684
6718
assert ! ( route. is_err( ) ) ;
6685
6719
6686
6720
// Finally make sure we can route again, when we remove the ban.
6687
6721
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6688
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6689
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6722
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6690
6723
assert ! ( route. is_ok( ) ) ;
6691
6724
}
6692
6725
0 commit comments