@@ -7,7 +7,7 @@ use util::byte_utils;
7
7
use util:: events:: { EventsProvider , Event } ;
8
8
use util:: logger:: Logger ;
9
9
10
- use std:: collections:: { HashMap , LinkedList } ;
10
+ use std:: collections:: { HashMap , hash_map , LinkedList } ;
11
11
use std:: sync:: { Arc , Mutex } ;
12
12
use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
13
13
use std:: { cmp, error, mem, hash, fmt} ;
@@ -90,6 +90,18 @@ struct PeerHolder<Descriptor: SocketDescriptor> {
90
90
/// Only add to this set when noise completes:
91
91
node_id_to_descriptor : HashMap < PublicKey , Descriptor > ,
92
92
}
93
+ struct MutPeerHolder < ' a , Descriptor : SocketDescriptor + ' a > {
94
+ peers : & ' a mut HashMap < Descriptor , Peer > ,
95
+ node_id_to_descriptor : & ' a mut HashMap < PublicKey , Descriptor > ,
96
+ }
97
+ impl < Descriptor : SocketDescriptor > PeerHolder < Descriptor > {
98
+ fn borrow_parts ( & mut self ) -> MutPeerHolder < Descriptor > {
99
+ MutPeerHolder {
100
+ peers : & mut self . peers ,
101
+ node_id_to_descriptor : & mut self . node_id_to_descriptor ,
102
+ }
103
+ }
104
+ }
93
105
94
106
pub struct PeerManager < Descriptor : SocketDescriptor > {
95
107
message_handler : MessageHandler ,
@@ -100,7 +112,6 @@ pub struct PeerManager<Descriptor: SocketDescriptor> {
100
112
logger : Arc < Logger > ,
101
113
}
102
114
103
-
104
115
macro_rules! encode_msg {
105
116
( $msg: expr, $msg_code: expr) => {
106
117
{
@@ -136,7 +147,12 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
136
147
/// completed and we are sure the remote peer has the private key for the given node_id.
137
148
pub fn get_peer_node_ids ( & self ) -> Vec < PublicKey > {
138
149
let peers = self . peers . lock ( ) . unwrap ( ) ;
139
- peers. peers . values ( ) . filter_map ( |p| p. their_node_id ) . collect ( )
150
+ peers. peers . values ( ) . filter_map ( |p| {
151
+ if !p. channel_encryptor . is_ready_for_encryption ( ) || p. their_global_features . is_none ( ) {
152
+ return None ;
153
+ }
154
+ p. their_node_id
155
+ } ) . collect ( )
140
156
}
141
157
142
158
/// Indicates a new outbound connection has been established to a node with the given node_id.
@@ -267,14 +283,14 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
267
283
268
284
fn do_read_event ( & self , peer_descriptor : & mut Descriptor , data : Vec < u8 > ) -> Result < bool , PeerHandleError > {
269
285
let pause_read = {
270
- let mut peers = self . peers . lock ( ) . unwrap ( ) ;
271
- let ( should_insert_node_id, pause_read) = match peers. peers . get_mut ( peer_descriptor) {
286
+ let mut peers_lock = self . peers . lock ( ) . unwrap ( ) ;
287
+ let peers = peers_lock. borrow_parts ( ) ;
288
+ let pause_read = match peers. peers . get_mut ( peer_descriptor) {
272
289
None => panic ! ( "Descriptor for read_event is not already known to PeerManager" ) ,
273
290
Some ( peer) => {
274
291
assert ! ( peer. pending_read_buffer. len( ) > 0 ) ;
275
292
assert ! ( peer. pending_read_buffer. len( ) > peer. pending_read_buffer_pos) ;
276
293
277
- let mut insert_node_id = None ;
278
294
let mut read_pos = 0 ;
279
295
while read_pos < data. len ( ) {
280
296
{
@@ -353,6 +369,18 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
353
369
}
354
370
}
355
371
372
+ macro_rules! insert_node_id {
373
+ ( ) => {
374
+ match peers. node_id_to_descriptor. entry( peer. their_node_id. unwrap( ) ) {
375
+ hash_map:: Entry :: Occupied ( _) => {
376
+ peer. their_node_id = None ; // Unset so that we don't generate a peer_disconnected event
377
+ return Err ( PeerHandleError { no_connection_possible: false } )
378
+ } ,
379
+ hash_map:: Entry :: Vacant ( entry) => entry. insert( peer_descriptor. clone( ) ) ,
380
+ } ;
381
+ }
382
+ }
383
+
356
384
let next_step = peer. channel_encryptor . get_noise_step ( ) ;
357
385
match next_step {
358
386
NextNoiseStep :: ActOne => {
@@ -366,7 +394,7 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
366
394
peer. pending_read_buffer = [ 0 ; 18 ] . to_vec ( ) ; // Message length header is 18 bytes
367
395
peer. pending_read_is_header = true ;
368
396
369
- insert_node_id = Some ( peer . their_node_id . unwrap ( ) ) ;
397
+ insert_node_id ! ( ) ;
370
398
let mut local_features = msgs:: LocalFeatures :: new ( ) ;
371
399
if self . initial_syncs_sent . load ( Ordering :: Acquire ) < INITIAL_SYNCS_TO_SEND {
372
400
self . initial_syncs_sent . fetch_add ( 1 , Ordering :: AcqRel ) ;
@@ -382,7 +410,7 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
382
410
peer. pending_read_buffer = [ 0 ; 18 ] . to_vec ( ) ; // Message length header is 18 bytes
383
411
peer. pending_read_is_header = true ;
384
412
peer. their_node_id = Some ( their_node_id) ;
385
- insert_node_id = Some ( peer . their_node_id . unwrap ( ) ) ;
413
+ insert_node_id ! ( ) ;
386
414
} ,
387
415
NextNoiseStep :: NoiseComplete => {
388
416
if peer. pending_read_is_header {
@@ -417,6 +445,9 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
417
445
if msg. local_features . requires_unknown_bits ( ) {
418
446
return Err ( PeerHandleError { no_connection_possible : true } ) ;
419
447
}
448
+ if peer. their_global_features . is_some ( ) {
449
+ return Err ( PeerHandleError { no_connection_possible : false } ) ;
450
+ }
420
451
peer. their_global_features = Some ( msg. global_features ) ;
421
452
peer. their_local_features = Some ( msg. local_features ) ;
422
453
@@ -607,15 +638,10 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
607
638
608
639
Self :: do_attempt_write_data ( peer_descriptor, peer) ;
609
640
610
- ( insert_node_id /* should_insert_node_id */ , peer. pending_outbound_buffer . len ( ) > 10 ) // pause_read
641
+ peer. pending_outbound_buffer . len ( ) > 10 // pause_read
611
642
}
612
643
} ;
613
644
614
- match should_insert_node_id {
615
- Some ( node_id) => { peers. node_id_to_descriptor . insert ( node_id, peer_descriptor. clone ( ) ) ; } ,
616
- None => { }
617
- } ;
618
-
619
645
pause_read
620
646
} ;
621
647
0 commit comments