Skip to content

Commit 21aa932

Browse files
committed
Use CandidateRouteHop as input for channel_penalty_msat
We remove `source`, `target` and `scid` from `channel_penalty_msat` inputs to consume them from `candidate` of type `CandidateRouteHop`
1 parent fb5a3f0 commit 21aa932

File tree

4 files changed

+463
-187
lines changed

4 files changed

+463
-187
lines changed

lightning-background-processor/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,8 +863,8 @@ mod tests {
863863
use lightning::ln::msgs::{ChannelMessageHandler, Init};
864864
use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
865865
use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync};
866-
use lightning::routing::router::{DefaultRouter, Path, RouteHop};
867866
use lightning::routing::scoring::{ChannelUsage, ScoreUpdate, ScoreLookUp, LockableScore};
867+
use lightning::routing::router::{DefaultRouter, Path, RouteHop, CandidateRouteHop};
868868
use lightning::util::config::UserConfig;
869869
use lightning::util::ser::Writeable;
870870
use lightning::util::test_utils;
@@ -1071,7 +1071,7 @@ mod tests {
10711071
impl ScoreLookUp for TestScorer {
10721072
type ScoreParams = ();
10731073
fn channel_penalty_msat(
1074-
&self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage, _score_params: &Self::ScoreParams
1074+
&self, _candidate: &CandidateRouteHop, _usage: ChannelUsage, _score_params: &Self::ScoreParams
10751075
) -> u64 { unimplemented!(); }
10761076
}
10771077

lightning/src/routing/router.rs

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
130130

131131
impl<'a, S: Deref> ScoreLookUp for ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: ScoreLookUp {
132132
type ScoreParams = <S::Target as ScoreLookUp>::ScoreParams;
133-
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
133+
fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
134+
let target = match candidate.target() {
135+
Some(target) => target,
136+
None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
137+
};
138+
let short_channel_id = match candidate.short_channel_id() {
139+
Some(short_channel_id) => short_channel_id,
140+
None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
141+
};
142+
let source = candidate.source();
134143
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
135-
source, target, short_channel_id
144+
&source, &target, short_channel_id
136145
) {
137146
let usage = ChannelUsage {
138147
inflight_htlc_msat: usage.inflight_htlc_msat.saturating_add(used_liquidity),
139148
..usage
140149
};
141150

142-
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
151+
self.scorer.channel_penalty_msat(candidate, usage, score_params)
143152
} else {
144-
self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
153+
self.scorer.channel_penalty_msat(candidate, usage, score_params)
145154
}
146155
}
147156
}
@@ -1072,7 +1081,7 @@ impl<'a> CandidateRouteHop<'a> {
10721081
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
10731082
pub fn short_channel_id(&self) -> Option<u64> {
10741083
match self {
1075-
CandidateRouteHop::FirstHop { details, .. } => Some(details.get_outbound_payment_scid().unwrap()),
1084+
CandidateRouteHop::FirstHop { details, .. } => details.get_outbound_payment_scid(),
10761085
CandidateRouteHop::PublicHop { short_channel_id, .. } => Some(*short_channel_id),
10771086
CandidateRouteHop::PrivateHop { hint, .. } => Some(hint.short_channel_id),
10781087
CandidateRouteHop::Blinded { .. } => None,
@@ -1177,7 +1186,7 @@ impl<'a> CandidateRouteHop<'a> {
11771186
CandidateRouteHop::PublicHop { source_node_id, .. } => *source_node_id,
11781187
CandidateRouteHop::PrivateHop { hint, .. } => hint.src_node_id.into(),
11791188
CandidateRouteHop::Blinded { hint, .. } => hint.1.introduction_node_id.into(),
1180-
CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into()
1189+
CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into(),
11811190
}
11821191
}
11831192
/// Returns the target node id of this hop, if known.
@@ -1929,7 +1938,7 @@ where L::Target: Logger {
19291938
);
19301939
let path_htlc_minimum_msat = compute_fees_saturating(curr_min, $candidate.fees())
19311940
.saturating_add(curr_min);
1932-
let hm_entry = dist.entry(&src_node_id);
1941+
let hm_entry = dist.entry(src_node_id);
19331942
let old_entry = hm_entry.or_insert_with(|| {
19341943
// If there was previously no known way to access the source node
19351944
// (recall it goes payee-to-payer) of short_channel_id, first add a
@@ -1985,9 +1994,10 @@ where L::Target: Logger {
19851994
inflight_htlc_msat: used_liquidity_msat,
19861995
effective_capacity,
19871996
};
1988-
let channel_penalty_msat = scid_opt.map_or(0,
1989-
|scid| scorer.channel_penalty_msat(scid, &src_node_id, &dest_node_id,
1990-
channel_usage, score_params));
1997+
let channel_penalty_msat =
1998+
scorer.channel_penalty_msat($candidate,
1999+
channel_usage,
2000+
score_params);
19912001
let path_penalty_msat = $next_hops_path_penalty_msat
19922002
.saturating_add(channel_penalty_msat);
19932003
let new_graph_node = RouteGraphNode {
@@ -2295,7 +2305,7 @@ where L::Target: Logger {
22952305
effective_capacity: candidate.effective_capacity(),
22962306
};
22972307
let channel_penalty_msat = scorer.channel_penalty_msat(
2298-
hop.short_channel_id, &source, &target, channel_usage, score_params
2308+
&candidate, channel_usage, score_params
22992309
);
23002310
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
23012311
.saturating_add(channel_penalty_msat);
@@ -2850,13 +2860,13 @@ fn build_route_from_hops_internal<L: Deref>(
28502860

28512861
impl ScoreLookUp for HopScorer {
28522862
type ScoreParams = ();
2853-
fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
2863+
fn channel_penalty_msat(&self, candidate: &CandidateRouteHop,
28542864
_usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64
28552865
{
28562866
let mut cur_id = self.our_node_id;
28572867
for i in 0..self.hop_ids.len() {
28582868
if let Some(next_id) = self.hop_ids[i] {
2859-
if cur_id == *source && next_id == *target {
2869+
if cur_id == candidate.source() && Some(next_id) == candidate.target() {
28602870
return 0;
28612871
}
28622872
cur_id = next_id;
@@ -2897,7 +2907,7 @@ mod tests {
28972907
use crate::routing::utxo::UtxoResult;
28982908
use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
28992909
BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
2900-
DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters};
2910+
DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters, CandidateRouteHop};
29012911
use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, ScoreLookUp, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
29022912
use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
29032913
use crate::chain::transaction::OutPoint;
@@ -6202,8 +6212,8 @@ mod tests {
62026212
}
62036213
impl ScoreLookUp for BadChannelScorer {
62046214
type ScoreParams = ();
6205-
fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
6206-
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
6215+
fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
6216+
if candidate.short_channel_id() == Some(self.short_channel_id) { u64::max_value() } else { 0 }
62076217
}
62086218
}
62096219

@@ -6218,8 +6228,8 @@ mod tests {
62186228

62196229
impl ScoreLookUp for BadNodeScorer {
62206230
type ScoreParams = ();
6221-
fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
6222-
if *target == self.node_id { u64::max_value() } else { 0 }
6231+
fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
6232+
if candidate.target() == Some(self.node_id) { u64::max_value() } else { 0 }
62236233
}
62246234
}
62256235

@@ -6707,26 +6717,34 @@ mod tests {
67076717
};
67086718
scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123);
67096719
scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456);
6710-
assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage, &scorer_params), 456);
6720+
let network_graph = network_graph.read_only();
6721+
let channels = network_graph.channels();
6722+
let channel = channels.get(&5).unwrap();
6723+
let info = channel.as_directed_from(&NodeId::from_pubkey(&nodes[3])).unwrap();
6724+
let candidate: CandidateRouteHop = CandidateRouteHop::PublicHop {
6725+
info: info.0,
6726+
short_channel_id: 5,
6727+
source_node_id: NodeId::from_pubkey(&nodes[3]),
6728+
target_node_id: NodeId::from_pubkey(&nodes[4]),
6729+
};
6730+
assert_eq!(scorer.channel_penalty_msat(&candidate, usage, &scorer_params), 456);
67116731

67126732
// Then check we can get a normal route
67136733
let payment_params = PaymentParameters::from_node_id(nodes[10], 42);
67146734
let route_params = RouteParameters::from_payment_params_and_value(
67156735
payment_params, 100);
6716-
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
6736+
let route = get_route(&our_id, &route_params, &network_graph, None,
67176737
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
67186738
assert!(route.is_ok());
67196739

67206740
// Then check that we can't get a route if we ban an intermediate node.
67216741
scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3]));
6722-
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
6723-
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
6742+
let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
67246743
assert!(route.is_err());
67256744

67266745
// Finally make sure we can route again, when we remove the ban.
67276746
scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3]));
6728-
let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
6729-
Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
6747+
let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
67306748
assert!(route.is_ok());
67316749
}
67326750

0 commit comments

Comments
 (0)