Skip to content

Include source and destination nodes in routing::Score #1133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lightning/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ pub mod network_graph;
pub mod router;
pub mod scorer;

use routing::network_graph::NodeId;

/// An interface used to score payment channels for path finding.
///
/// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
pub trait Score {
/// Returns the fee in msats willing to be paid to avoid routing through the given channel.
fn channel_penalty_msat(&self, short_channel_id: u64) -> u64;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit +"in the given direction" on the docs.

/// Returns the fee in msats willing to be paid to avoid routing through the given channel
/// in the direction from `source` to `target`.
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
}
113 changes: 83 additions & 30 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ where L::Target: Logger {
}

let path_penalty_msat = $next_hops_path_penalty_msat
.checked_add(scorer.channel_penalty_msat($chan_id.clone()))
.checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
.unwrap_or_else(|| u64::max_value());
let new_graph_node = RouteGraphNode {
node_id: $src_node_id,
Expand Down Expand Up @@ -973,15 +973,17 @@ where L::Target: Logger {
_ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value())
}) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity

let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id))
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
.unwrap_or_else(|| u64::max_value());

// We assume that the recipient only included route hints for routes which had
// sufficient value to route `final_value_msat`. Note that in the case of "0-value"
// invoices where the invoice does not specify value this may not be the case, but
// better to include the hints than not.
if !add_entry!(hop.short_channel_id, NodeId::from_pubkey(&hop.src_node_id), NodeId::from_pubkey(&prev_hop_id), directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
if !add_entry!(hop.short_channel_id, src_node_id, dest_node_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
// If this hop was not used then there is no use checking the preceding hops
// in the RouteHint. We can break by just searching for a direct channel between
// last checked hop and first_hop_targets
Expand Down Expand Up @@ -1322,7 +1324,8 @@ where L::Target: Logger {

#[cfg(test)]
mod tests {
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
use routing;
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
use routing::scorer::Scorer;
use chain::transaction::OutPoint;
Expand Down Expand Up @@ -4351,42 +4354,92 @@ mod tests {
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);

// Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
let scorer = Scorer::new(0);
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.get_total_fees(), 100);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 6, 11, 8]);

// Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
// from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
let scorer = Scorer::new(100);
let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
assert_eq!(route.paths[0].len(), 4);
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.paths[0][0].pubkey, nodes[1]);
assert_eq!(route.paths[0][0].short_channel_id, 2);
assert_eq!(route.paths[0][0].fee_msat, 200);
assert_eq!(route.paths[0][0].cltv_expiry_delta, (4 << 8) | 1);
assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2));
assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2));
assert_eq!(route.get_total_fees(), 300);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 7, 10]);
}

assert_eq!(route.paths[0][1].pubkey, nodes[2]);
assert_eq!(route.paths[0][1].short_channel_id, 4);
assert_eq!(route.paths[0][1].fee_msat, 100);
assert_eq!(route.paths[0][1].cltv_expiry_delta, (7 << 8) | 1);
assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3));
assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4));
struct BadChannelScorer {
short_channel_id: u64,
}

assert_eq!(route.paths[0][2].pubkey, nodes[5]);
assert_eq!(route.paths[0][2].short_channel_id, 7);
assert_eq!(route.paths[0][2].fee_msat, 0);
assert_eq!(route.paths[0][2].cltv_expiry_delta, (10 << 8) | 1);
assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(6));
assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(7));
impl routing::Score for BadChannelScorer {
fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
}
}

assert_eq!(route.paths[0][3].pubkey, nodes[6]);
assert_eq!(route.paths[0][3].short_channel_id, 10);
assert_eq!(route.paths[0][3].fee_msat, 100);
assert_eq!(route.paths[0][3].cltv_expiry_delta, 42);
assert_eq!(route.paths[0][3].node_features.le_flags(), &Vec::<u8>::new()); // We don't pass flags in from invoices yet
assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
struct BadNodeScorer {
node_id: NodeId,
}

impl routing::Score for BadNodeScorer {
fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
if *target == self.node_id { u64::max_value() } else { 0 }
}
}

#[test]
fn avoids_routing_through_bad_channels_and_nodes() {
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);

// A path to nodes[6] exists when no penalties are applied to any channel.
let scorer = Scorer::new(0);
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.get_total_fees(), 100);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 6, 11, 8]);

// A different path to nodes[6] exists if channel 6 cannot be routed over.
let scorer = BadChannelScorer { short_channel_id: 6 };
let route = get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
).unwrap();
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();

assert_eq!(route.get_total_fees(), 300);
assert_eq!(route.get_total_amount(), 100);
assert_eq!(path, vec![2, 4, 7, 10]);

// A path to nodes[6] does not exist if nodes[2] cannot be routed through.
let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) };
match get_route(
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
) {
Err(LightningError { err, .. } ) => {
assert_eq!(err, "Failed to find a path to the given destination");
},
Ok(_) => panic!("Expected error"),
}
}

#[test]
Expand Down
8 changes: 7 additions & 1 deletion lightning/src/routing/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

use routing;

use routing::network_graph::NodeId;

/// [`routing::Score`] implementation that provides reasonable default behavior.
///
/// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
Expand Down Expand Up @@ -71,5 +73,9 @@ impl Default for Scorer {
}

impl routing::Score for Scorer {
fn channel_penalty_msat(&self, _short_channel_id: u64) -> u64 { self.base_penalty_msat }
fn channel_penalty_msat(
&self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
) -> u64 {
self.base_penalty_msat
}
}