@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
130
130
131
131
impl < ' a , S : Deref > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , S > where S :: Target : ScoreLookUp {
132
132
type ScoreParams = <S :: Target as ScoreLookUp >:: ScoreParams ;
133
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
133
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
134
+ let target = match candidate. target ( ) {
135
+ Some ( target) => target,
136
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
137
+ } ;
138
+ let short_channel_id = match candidate. short_channel_id ( ) {
139
+ Some ( short_channel_id) => short_channel_id,
140
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
141
+ } ;
142
+ let source = candidate. source ( ) ;
134
143
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
135
- source, target, short_channel_id
144
+ & source, & target, short_channel_id
136
145
) {
137
146
let usage = ChannelUsage {
138
147
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
139
148
..usage
140
149
} ;
141
150
142
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
151
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
143
152
} else {
144
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
153
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
145
154
}
146
155
}
147
156
}
@@ -1072,7 +1081,7 @@ impl<'a> CandidateRouteHop<'a> {
1072
1081
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
1073
1082
pub fn short_channel_id ( & self ) -> Option < u64 > {
1074
1083
match self {
1075
- CandidateRouteHop :: FirstHop { details, .. } => Some ( details. get_outbound_payment_scid ( ) . unwrap ( ) ) ,
1084
+ CandidateRouteHop :: FirstHop { details, .. } => details. get_outbound_payment_scid ( ) ,
1076
1085
CandidateRouteHop :: PublicHop { short_channel_id, .. } => Some ( * short_channel_id) ,
1077
1086
CandidateRouteHop :: PrivateHop { hint, .. } => Some ( hint. short_channel_id ) ,
1078
1087
CandidateRouteHop :: Blinded { .. } => None ,
@@ -1177,7 +1186,7 @@ impl<'a> CandidateRouteHop<'a> {
1177
1186
CandidateRouteHop :: PublicHop { source_node_id, .. } => * source_node_id,
1178
1187
CandidateRouteHop :: PrivateHop { hint, .. } => hint. src_node_id . into ( ) ,
1179
1188
CandidateRouteHop :: Blinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1180
- CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( )
1189
+ CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1181
1190
}
1182
1191
}
1183
1192
/// Returns the target node id of this hop, if known.
@@ -1929,7 +1938,7 @@ where L::Target: Logger {
1929
1938
) ;
1930
1939
let path_htlc_minimum_msat = compute_fees_saturating( curr_min, $candidate. fees( ) )
1931
1940
. saturating_add( curr_min) ;
1932
- let hm_entry = dist. entry( & src_node_id) ;
1941
+ let hm_entry = dist. entry( src_node_id) ;
1933
1942
let old_entry = hm_entry. or_insert_with( || {
1934
1943
// If there was previously no known way to access the source node
1935
1944
// (recall it goes payee-to-payer) of short_channel_id, first add a
@@ -1985,9 +1994,10 @@ where L::Target: Logger {
1985
1994
inflight_htlc_msat: used_liquidity_msat,
1986
1995
effective_capacity,
1987
1996
} ;
1988
- let channel_penalty_msat = scid_opt. map_or( 0 ,
1989
- |scid| scorer. channel_penalty_msat( scid, & src_node_id, & dest_node_id,
1990
- channel_usage, score_params) ) ;
1997
+ let channel_penalty_msat =
1998
+ scorer. channel_penalty_msat( $candidate,
1999
+ channel_usage,
2000
+ score_params) ;
1991
2001
let path_penalty_msat = $next_hops_path_penalty_msat
1992
2002
. saturating_add( channel_penalty_msat) ;
1993
2003
let new_graph_node = RouteGraphNode {
@@ -2295,7 +2305,7 @@ where L::Target: Logger {
2295
2305
effective_capacity : candidate. effective_capacity ( ) ,
2296
2306
} ;
2297
2307
let channel_penalty_msat = scorer. channel_penalty_msat (
2298
- hop . short_channel_id , & source , & target , channel_usage, score_params
2308
+ & candidate , channel_usage, score_params
2299
2309
) ;
2300
2310
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2301
2311
. saturating_add ( channel_penalty_msat) ;
@@ -2850,13 +2860,13 @@ fn build_route_from_hops_internal<L: Deref>(
2850
2860
2851
2861
impl ScoreLookUp for HopScorer {
2852
2862
type ScoreParams = ( ) ;
2853
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2863
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2854
2864
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2855
2865
{
2856
2866
let mut cur_id = self . our_node_id ;
2857
2867
for i in 0 ..self . hop_ids . len ( ) {
2858
2868
if let Some ( next_id) = self . hop_ids [ i] {
2859
- if cur_id == * source && next_id == * target {
2869
+ if cur_id == candidate . source ( ) && Some ( next_id) == candidate . target ( ) {
2860
2870
return 0 ;
2861
2871
}
2862
2872
cur_id = next_id;
@@ -2897,7 +2907,7 @@ mod tests {
2897
2907
use crate :: routing:: utxo:: UtxoResult ;
2898
2908
use crate :: routing:: router:: { get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
2899
2909
BlindedTail , InFlightHtlcs , Path , PaymentParameters , Route , RouteHint , RouteHintHop , RouteHop , RoutingFees ,
2900
- DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA , MAX_PATH_LENGTH_ESTIMATE , RouteParameters } ;
2910
+ DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA , MAX_PATH_LENGTH_ESTIMATE , RouteParameters , CandidateRouteHop } ;
2901
2911
use crate :: routing:: scoring:: { ChannelUsage , FixedPenaltyScorer , ScoreLookUp , ProbabilisticScorer , ProbabilisticScoringFeeParameters , ProbabilisticScoringDecayParameters } ;
2902
2912
use crate :: routing:: test_utils:: { add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel} ;
2903
2913
use crate :: chain:: transaction:: OutPoint ;
@@ -6202,8 +6212,8 @@ mod tests {
6202
6212
}
6203
6213
impl ScoreLookUp for BadChannelScorer {
6204
6214
type ScoreParams = ( ) ;
6205
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _ : & NodeId , _ : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6206
- if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6215
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6216
+ if candidate . short_channel_id ( ) == Some ( self . short_channel_id ) { u64:: max_value ( ) } else { 0 }
6207
6217
}
6208
6218
}
6209
6219
@@ -6218,8 +6228,8 @@ mod tests {
6218
6228
6219
6229
impl ScoreLookUp for BadNodeScorer {
6220
6230
type ScoreParams = ( ) ;
6221
- fn channel_penalty_msat ( & self , _ : u64 , _ : & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6222
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6231
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6232
+ if candidate . target ( ) == Some ( self . node_id ) { u64:: max_value ( ) } else { 0 }
6223
6233
}
6224
6234
}
6225
6235
@@ -6707,26 +6717,34 @@ mod tests {
6707
6717
} ;
6708
6718
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6709
6719
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6710
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6720
+ let network_graph = network_graph. read_only ( ) ;
6721
+ let channels = network_graph. channels ( ) ;
6722
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6723
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6724
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6725
+ info : info. 0 ,
6726
+ short_channel_id : 5 ,
6727
+ source_node_id : NodeId :: from_pubkey ( & nodes[ 3 ] ) ,
6728
+ target_node_id : NodeId :: from_pubkey ( & nodes[ 4 ] ) ,
6729
+ } ;
6730
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6711
6731
6712
6732
// Then check we can get a normal route
6713
6733
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6714
6734
let route_params = RouteParameters :: from_payment_params_and_value (
6715
6735
payment_params, 100 ) ;
6716
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6736
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6717
6737
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6718
6738
assert ! ( route. is_ok( ) ) ;
6719
6739
6720
6740
// Then check that we can't get a route if we ban an intermediate node.
6721
6741
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6722
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6723
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6742
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6724
6743
assert ! ( route. is_err( ) ) ;
6725
6744
6726
6745
// Finally make sure we can route again, when we remove the ban.
6727
6746
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6728
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6729
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6747
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6730
6748
assert ! ( route. is_ok( ) ) ;
6731
6749
}
6732
6750
0 commit comments