@@ -64,7 +64,7 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
64
64
const MAX_SCIDS_PER_REPLY : usize = 8000 ;
65
65
66
66
/// Represents the compressed public key of a node
67
- #[ derive( Clone , Copy ) ]
67
+ #[ derive( Clone , Copy , PartialEq , Eq ) ]
68
68
pub struct NodeId ( [ u8 ; PUBLIC_KEY_SIZE ] ) ;
69
69
70
70
impl NodeId {
@@ -116,14 +116,6 @@ impl core::hash::Hash for NodeId {
116
116
}
117
117
}
118
118
119
- impl Eq for NodeId { }
120
-
121
- impl PartialEq for NodeId {
122
- fn eq ( & self , other : & Self ) -> bool {
123
- self . 0 [ ..] == other. 0 [ ..]
124
- }
125
- }
126
-
127
119
impl cmp:: PartialOrd for NodeId {
128
120
fn partial_cmp ( & self , other : & Self ) -> Option < cmp:: Ordering > {
129
121
Some ( self . cmp ( other) )
@@ -885,7 +877,7 @@ impl Readable for ChannelUpdateInfo {
885
877
}
886
878
}
887
879
888
- #[ derive( Clone , Debug , PartialEq , Eq ) ]
880
+ #[ derive( Clone , Debug , Eq ) ]
889
881
/// Details about a channel (both directions).
890
882
/// Received within a channel announcement.
891
883
pub struct ChannelInfo {
@@ -910,6 +902,24 @@ pub struct ChannelInfo {
910
902
/// (which we can probably assume we are - no-std environments probably won't have a full
911
903
/// network graph in memory!).
912
904
announcement_received_time : u64 ,
905
+
906
+ /// The [`NodeInfo::node_counter`] of the node pointed to by [`Self::node_one`].
907
+ pub ( crate ) node_one_counter : u32 ,
908
+ /// The [`NodeInfo::node_counter`] of the node pointed to by [`Self::node_two`].
909
+ pub ( crate ) node_two_counter : u32 ,
910
+ }
911
+
912
+ impl PartialEq for ChannelInfo {
913
+ fn eq ( & self , o : & ChannelInfo ) -> bool {
914
+ self . features == o. features &&
915
+ self . node_one == o. node_one &&
916
+ self . one_to_two == o. one_to_two &&
917
+ self . node_two == o. node_two &&
918
+ self . two_to_one == o. two_to_one &&
919
+ self . capacity_sats == o. capacity_sats &&
920
+ self . announcement_message == o. announcement_message &&
921
+ self . announcement_received_time == o. announcement_received_time
922
+ }
913
923
}
914
924
915
925
impl ChannelInfo {
@@ -1030,6 +1040,8 @@ impl Readable for ChannelInfo {
1030
1040
capacity_sats : _init_tlv_based_struct_field ! ( capacity_sats, required) ,
1031
1041
announcement_message : _init_tlv_based_struct_field ! ( announcement_message, required) ,
1032
1042
announcement_received_time : _init_tlv_based_struct_field ! ( announcement_received_time, ( default_value, 0 ) ) ,
1043
+ node_one_counter : u32:: max_value ( ) ,
1044
+ node_two_counter : u32:: max_value ( ) ,
1033
1045
} )
1034
1046
}
1035
1047
}
@@ -1505,7 +1517,7 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
1505
1517
let mut channels = IndexedMap :: with_capacity ( cmp:: min ( channels_count as usize , 22500 ) ) ;
1506
1518
for _ in 0 ..channels_count {
1507
1519
let chan_id: u64 = Readable :: read ( reader) ?;
1508
- let chan_info = Readable :: read ( reader) ?;
1520
+ let chan_info: ChannelInfo = Readable :: read ( reader) ?;
1509
1521
channels. insert ( chan_id, chan_info) ;
1510
1522
}
1511
1523
let nodes_count: u64 = Readable :: read ( reader) ?;
@@ -1521,6 +1533,13 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
1521
1533
nodes. insert ( node_id, node_info) ;
1522
1534
}
1523
1535
1536
+ for ( _, chan) in channels. unordered_iter_mut ( ) {
1537
+ chan. node_one_counter =
1538
+ nodes. get ( & chan. node_one ) . ok_or ( DecodeError :: InvalidValue ) ?. node_counter ;
1539
+ chan. node_two_counter =
1540
+ nodes. get ( & chan. node_two ) . ok_or ( DecodeError :: InvalidValue ) ?. node_counter ;
1541
+ }
1542
+
1524
1543
let mut last_rapid_gossip_sync_timestamp: Option < u32 > = None ;
1525
1544
read_tlv_fields ! ( reader, {
1526
1545
( 1 , last_rapid_gossip_sync_timestamp, option) ,
@@ -1590,6 +1609,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1590
1609
1591
1610
fn test_node_counter_consistency ( & self ) {
1592
1611
#[ cfg( debug_assertions) ] {
1612
+ let channels = self . channels . read ( ) . unwrap ( ) ;
1593
1613
let nodes = self . nodes . read ( ) . unwrap ( ) ;
1594
1614
let removed_node_counters = self . removed_node_counters . lock ( ) . unwrap ( ) ;
1595
1615
let next_counter = self . next_node_counter . load ( Ordering :: Acquire ) ;
@@ -1609,6 +1629,19 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1609
1629
assert_eq ! ( used_node_counters[ pos] & bit, 0 ) ;
1610
1630
used_node_counters[ pos] |= bit;
1611
1631
}
1632
+
1633
+ for ( idx, used_bitset) in used_node_counters. iter ( ) . enumerate ( ) {
1634
+ if idx != next_counter / 8 {
1635
+ assert_eq ! ( * used_bitset, 0xff ) ;
1636
+ } else {
1637
+ assert_eq ! ( * used_bitset, ( 1u8 << ( next_counter % 8 ) ) - 1 ) ;
1638
+ }
1639
+ }
1640
+
1641
+ for ( _, chan) in channels. unordered_iter ( ) {
1642
+ assert_eq ! ( chan. node_one_counter, nodes. get( & chan. node_one) . unwrap( ) . node_counter) ;
1643
+ assert_eq ! ( chan. node_two_counter, nodes. get( & chan. node_two) . unwrap( ) . node_counter) ;
1644
+ }
1612
1645
}
1613
1646
}
1614
1647
@@ -1773,6 +1806,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1773
1806
capacity_sats : None ,
1774
1807
announcement_message : None ,
1775
1808
announcement_received_time : timestamp,
1809
+ node_one_counter : u32:: max_value ( ) ,
1810
+ node_two_counter : u32:: max_value ( ) ,
1776
1811
} ;
1777
1812
1778
1813
self . add_channel_between_nodes ( short_channel_id, channel_info, None )
@@ -1787,7 +1822,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1787
1822
1788
1823
log_gossip ! ( self . logger, "Adding channel {} between nodes {} and {}" , short_channel_id, node_id_a, node_id_b) ;
1789
1824
1790
- match channels. entry ( short_channel_id) {
1825
+ let channel_info = match channels. entry ( short_channel_id) {
1791
1826
IndexedMapEntry :: Occupied ( mut entry) => {
1792
1827
//TODO: because asking the blockchain if short_channel_id is valid is only optional
1793
1828
//in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -1803,28 +1838,35 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1803
1838
// c) it's unclear how to do so without exposing ourselves to massive DoS risk.
1804
1839
self . remove_channel_in_nodes ( & mut nodes, & entry. get ( ) , short_channel_id) ;
1805
1840
* entry. get_mut ( ) = channel_info;
1841
+ entry. into_mut ( )
1806
1842
} else {
1807
1843
return Err ( LightningError { err : "Already have knowledge of channel" . to_owned ( ) , action : ErrorAction :: IgnoreDuplicateGossip } ) ;
1808
1844
}
1809
1845
} ,
1810
1846
IndexedMapEntry :: Vacant ( entry) => {
1811
- entry. insert ( channel_info) ;
1847
+ entry. insert ( channel_info)
1812
1848
}
1813
1849
} ;
1814
1850
1815
- for current_node_id in [ node_id_a, node_id_b] . iter ( ) {
1851
+ let mut node_counter_id = [
1852
+ ( & mut channel_info. node_one_counter , node_id_a) ,
1853
+ ( & mut channel_info. node_two_counter , node_id_b)
1854
+ ] ;
1855
+ for ( node_counter, current_node_id) in node_counter_id. iter_mut ( ) {
1816
1856
match nodes. entry ( current_node_id. clone ( ) ) {
1817
1857
IndexedMapEntry :: Occupied ( node_entry) => {
1818
- node_entry. into_mut ( ) . channels . push ( short_channel_id) ;
1858
+ let node = node_entry. into_mut ( ) ;
1859
+ node. channels . push ( short_channel_id) ;
1860
+ * * node_counter = node. node_counter ;
1819
1861
} ,
1820
1862
IndexedMapEntry :: Vacant ( node_entry) => {
1821
1863
let mut removed_node_counters = self . removed_node_counters . lock ( ) . unwrap ( ) ;
1822
- let node_counter = removed_node_counters. pop ( )
1864
+ * * node_counter = removed_node_counters. pop ( )
1823
1865
. unwrap_or ( self . next_node_counter . fetch_add ( 1 , Ordering :: Relaxed ) as u32 ) ;
1824
1866
node_entry. insert ( NodeInfo {
1825
1867
channels : vec ! ( short_channel_id) ,
1826
1868
announcement_info : None ,
1827
- node_counter,
1869
+ node_counter : * * node_counter ,
1828
1870
} ) ;
1829
1871
}
1830
1872
} ;
@@ -1915,6 +1957,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1915
1957
announcement_message : if msg. excess_data . len ( ) <= MAX_EXCESS_BYTES_FOR_RELAY
1916
1958
{ full_msg. cloned ( ) } else { None } ,
1917
1959
announcement_received_time,
1960
+ node_one_counter : u32:: max_value ( ) ,
1961
+ node_two_counter : u32:: max_value ( ) ,
1918
1962
} ;
1919
1963
1920
1964
self . add_channel_between_nodes ( msg. short_channel_id , chan_info, utxo_value) ?;
@@ -1976,6 +2020,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
1976
2020
}
1977
2021
}
1978
2022
removed_channels. insert ( * scid, current_time_unix) ;
2023
+ } else {
2024
+ debug_assert ! ( false , "Channels in nodes must always have channel info" ) ;
1979
2025
}
1980
2026
}
1981
2027
removed_node_counters. push ( node. node_counter ) ;
@@ -3595,6 +3641,8 @@ pub(crate) mod tests {
3595
3641
capacity_sats : None ,
3596
3642
announcement_message : None ,
3597
3643
announcement_received_time : 87654 ,
3644
+ node_one_counter : 0 ,
3645
+ node_two_counter : 1 ,
3598
3646
} ;
3599
3647
3600
3648
let mut encoded_chan_info: Vec < u8 > = Vec :: new ( ) ;
@@ -3613,6 +3661,8 @@ pub(crate) mod tests {
3613
3661
capacity_sats : None ,
3614
3662
announcement_message : None ,
3615
3663
announcement_received_time : 87654 ,
3664
+ node_one_counter : 0 ,
3665
+ node_two_counter : 1 ,
3616
3666
} ;
3617
3667
3618
3668
let mut encoded_chan_info: Vec < u8 > = Vec :: new ( ) ;
0 commit comments