@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
130
130
131
131
impl < ' a , S : Deref > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , S > where S :: Target : ScoreLookUp {
132
132
type ScoreParams = <S :: Target as ScoreLookUp >:: ScoreParams ;
133
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
133
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
134
+ let target = match candidate. target ( ) {
135
+ Some ( target) => target,
136
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
137
+ } ;
138
+ let short_channel_id = match candidate. short_channel_id ( ) {
139
+ Some ( short_channel_id) => short_channel_id,
140
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
141
+ } ;
142
+ let source = candidate. source ( ) ;
134
143
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
135
- source, target, short_channel_id
144
+ & source, & target, short_channel_id
136
145
) {
137
146
let usage = ChannelUsage {
138
147
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
139
148
..usage
140
149
} ;
141
150
142
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
151
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
143
152
} else {
144
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
153
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
145
154
}
146
155
}
147
156
}
@@ -1065,7 +1074,7 @@ impl<'a> CandidateRouteHop<'a> {
1065
1074
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
1066
1075
pub fn short_channel_id ( & self ) -> Option < u64 > {
1067
1076
match self {
1068
- CandidateRouteHop :: FirstHop { details, .. } => Some ( details. get_outbound_payment_scid ( ) . unwrap ( ) ) ,
1077
+ CandidateRouteHop :: FirstHop { details, .. } => details. get_outbound_payment_scid ( ) ,
1069
1078
CandidateRouteHop :: PublicHop { short_channel_id, .. } => Some ( * short_channel_id) ,
1070
1079
CandidateRouteHop :: PrivateHop { hint, .. } => Some ( hint. short_channel_id ) ,
1071
1080
CandidateRouteHop :: Blinded { .. } => None ,
@@ -1171,7 +1180,7 @@ impl<'a> CandidateRouteHop<'a> {
1171
1180
CandidateRouteHop :: PublicHop { source_node_id, .. } => * source_node_id,
1172
1181
CandidateRouteHop :: PrivateHop { hint, .. } => hint. src_node_id . into ( ) ,
1173
1182
CandidateRouteHop :: Blinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1174
- CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( )
1183
+ CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1175
1184
}
1176
1185
}
1177
1186
/// Returns the target node id of this hop, if known.
@@ -1798,7 +1807,7 @@ where L::Target: Logger {
1798
1807
let mut num_ignored_htlc_minimum_msat_limit: u32 = 0 ;
1799
1808
1800
1809
macro_rules! add_entry {
1801
- // Adds entry which goes from $src_node_id to $dest_node_id over the $candidate hop.
1810
+ // Adds entry which goes from candidate.source() to candiadte.target() over the $candidate hop.
1802
1811
// $next_hops_fee_msat represents the fees paid for using all the channels *after* this one,
1803
1812
// since that value has to be transferred over this channel.
1804
1813
// Returns the contribution amount of $candidate if the channel caused an update to `targets`.
@@ -1814,7 +1823,7 @@ where L::Target: Logger {
1814
1823
// - for first and last hops early in get_route
1815
1824
let src_node_id = $candidate. source( ) ;
1816
1825
let dest_node_id = $candidate. target( ) . unwrap_or( maybe_dummy_payee_node_id) ;
1817
- if src_node_id != dest_node_id {
1826
+ if Some ( $candidate . source ( ) ) != $candidate . target ( ) {
1818
1827
let scid_opt = $candidate. short_channel_id( ) ;
1819
1828
let effective_capacity = $candidate. effective_capacity( ) ;
1820
1829
let htlc_maximum_msat = max_htlc_from_capacity( effective_capacity, channel_saturation_pow_half) ;
@@ -1979,9 +1988,10 @@ where L::Target: Logger {
1979
1988
inflight_htlc_msat: used_liquidity_msat,
1980
1989
effective_capacity,
1981
1990
} ;
1982
- let channel_penalty_msat = scid_opt. map_or( 0 ,
1983
- |scid| scorer. channel_penalty_msat( scid, & src_node_id, & dest_node_id,
1984
- channel_usage, score_params) ) ;
1991
+ let channel_penalty_msat =
1992
+ scorer. channel_penalty_msat( $candidate,
1993
+ channel_usage,
1994
+ score_params) ;
1985
1995
let path_penalty_msat = $next_hops_path_penalty_msat
1986
1996
. saturating_add( channel_penalty_msat) ;
1987
1997
let new_graph_node = RouteGraphNode {
@@ -1994,7 +2004,7 @@ where L::Target: Logger {
1994
2004
path_length_to_node,
1995
2005
} ;
1996
2006
1997
- // Update the way of reaching $src_node_id with the given short_channel_id (from $dest_node_id ),
2007
+ // Update the way of reaching $candidate.source() with the given short_channel_id (from $candidate.target() ),
1998
2008
// if this way is cheaper than the already known
1999
2009
// (considering the cost to "reach" this channel from the route destination,
2000
2010
// the cost of using this channel,
@@ -2288,7 +2298,7 @@ where L::Target: Logger {
2288
2298
effective_capacity : candidate. effective_capacity ( ) ,
2289
2299
} ;
2290
2300
let channel_penalty_msat = scorer. channel_penalty_msat (
2291
- hop . short_channel_id , & source , & target , channel_usage, score_params
2301
+ & candidate , channel_usage, score_params
2292
2302
) ;
2293
2303
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2294
2304
. saturating_add ( channel_penalty_msat) ;
@@ -2652,7 +2662,6 @@ where L::Target: Logger {
2652
2662
let mut paths = Vec :: new ( ) ;
2653
2663
for payment_path in selected_route {
2654
2664
let mut hops = Vec :: with_capacity ( payment_path. hops . len ( ) ) ;
2655
- let mut prev_hop_node_id = our_node_id;
2656
2665
for ( hop, node_features) in payment_path. hops . iter ( )
2657
2666
. filter ( |( h, _) | h. candidate . short_channel_id ( ) . is_some ( ) )
2658
2667
{
@@ -2669,7 +2678,7 @@ where L::Target: Logger {
2669
2678
// an alias, in which case we don't take any chances here.
2670
2679
network_graph. node ( & hop. node_id ) . map_or ( false , |hop_node|
2671
2680
hop_node. channels . iter ( ) . any ( |scid| network_graph. channel ( * scid)
2672
- . map_or ( false , |c| c. as_directed_from ( & prev_hop_node_id ) . is_some ( ) ) )
2681
+ . map_or ( false , |c| c. as_directed_from ( & hop . candidate . source ( ) ) . is_some ( ) ) )
2673
2682
)
2674
2683
} ;
2675
2684
@@ -2682,8 +2691,6 @@ where L::Target: Logger {
2682
2691
cltv_expiry_delta : hop. candidate . cltv_expiry_delta ( ) ,
2683
2692
maybe_announced_channel,
2684
2693
} ) ;
2685
-
2686
- prev_hop_node_id = hop. node_id ;
2687
2694
}
2688
2695
let mut final_cltv_delta = final_cltv_expiry_delta;
2689
2696
let blinded_tail = payment_path. hops . last ( ) . and_then ( |( h, _) | {
@@ -2846,13 +2853,13 @@ fn build_route_from_hops_internal<L: Deref>(
2846
2853
2847
2854
impl ScoreLookUp for HopScorer {
2848
2855
type ScoreParams = ( ) ;
2849
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2856
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2850
2857
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2851
2858
{
2852
2859
let mut cur_id = self . our_node_id ;
2853
2860
for i in 0 ..self . hop_ids . len ( ) {
2854
2861
if let Some ( next_id) = self . hop_ids [ i] {
2855
- if cur_id == * source && next_id == * target {
2862
+ if cur_id == candidate . source ( ) && Some ( next_id) == candidate . target ( ) {
2856
2863
return 0 ;
2857
2864
}
2858
2865
cur_id = next_id;
@@ -2928,6 +2935,8 @@ mod tests {
2928
2935
2929
2936
use core:: convert:: TryInto ;
2930
2937
2938
+ use super :: CandidateRouteHop ;
2939
+
2931
2940
fn get_channel_details ( short_channel_id : Option < u64 > , node_id : PublicKey ,
2932
2941
features : InitFeatures , outbound_capacity_msat : u64 ) -> channelmanager:: ChannelDetails {
2933
2942
channelmanager:: ChannelDetails {
@@ -6200,7 +6209,11 @@ mod tests {
6200
6209
}
6201
6210
impl ScoreLookUp for BadChannelScorer {
6202
6211
type ScoreParams = ( ) ;
6203
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _: & NodeId , _: & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6212
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6213
+ let short_channel_id = match candidate. short_channel_id ( ) {
6214
+ Some ( id) => id,
6215
+ None => return 0 ,
6216
+ } ;
6204
6217
if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6205
6218
}
6206
6219
}
@@ -6216,8 +6229,8 @@ mod tests {
6216
6229
6217
6230
impl ScoreLookUp for BadNodeScorer {
6218
6231
type ScoreParams = ( ) ;
6219
- fn channel_penalty_msat ( & self , _ : u64 , _ : & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6220
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6232
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6233
+ if candidate . target ( ) == Some ( self . node_id ) { u64:: max_value ( ) } else { 0 }
6221
6234
}
6222
6235
}
6223
6236
@@ -6705,26 +6718,34 @@ mod tests {
6705
6718
} ;
6706
6719
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6707
6720
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6708
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6721
+ let network_graph = network_graph. read_only ( ) ;
6722
+ let channels = network_graph. channels ( ) ;
6723
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6724
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6725
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6726
+ info : info. 0 ,
6727
+ short_channel_id : 5 ,
6728
+ source_node_id : NodeId :: from_pubkey ( & nodes[ 3 ] ) ,
6729
+ target_node_id : NodeId :: from_pubkey ( & nodes[ 4 ] ) ,
6730
+ } ;
6731
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6709
6732
6710
6733
// Then check we can get a normal route
6711
6734
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6712
6735
let route_params = RouteParameters :: from_payment_params_and_value (
6713
6736
payment_params, 100 ) ;
6714
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6737
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6715
6738
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6716
6739
assert ! ( route. is_ok( ) ) ;
6717
6740
6718
6741
// Then check that we can't get a route if we ban an intermediate node.
6719
6742
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6720
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6721
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6743
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6722
6744
assert ! ( route. is_err( ) ) ;
6723
6745
6724
6746
// Finally make sure we can route again, when we remove the ban.
6725
6747
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6726
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6727
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6748
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6728
6749
assert ! ( route. is_ok( ) ) ;
6729
6750
}
6730
6751
0 commit comments