Skip to content

Commit aa43ca9

Browse files
committed
refactor: Make Transport the source of their_node_id
This patch ontinues to separate state the exists before NOISE is complete and after it is complete to unlock future refactoring. Most callers immediately unwrapped the value from Peer and can just call Transport::get_their_node_id(). The duplicate connection disconnect path has been rewritten to determine whether or not to remove & send a disconnect event without needing to use a None value for Option<PublicKey> All other users are in contexts where they either exit early or continue if !transport.is_connected() so it is also safe to call Transport::get_their_node_id()
1 parent 8e75a53 commit aa43ca9

File tree

2 files changed

+90
-55
lines changed

2 files changed

+90
-55
lines changed

lightning/src/ln/peers/handler.rs

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ pub(super) trait ITransport {
6060
/// Returns true if the connection is established and encrypted messages can be sent.
6161
fn is_connected(&self) -> bool;
6262

63+
/// Returns the node_id of the remote node. Panics if not connected.
64+
fn get_their_node_id(&self) -> PublicKey;
65+
6366
/// Returns all Messages that have been received and can be parsed by the Transport
6467
fn drain_messages<L: Deref>(&mut self, logger: L) -> Result<Vec<Message>, PeerHandleError> where L::Target: Logger;
6568

@@ -193,7 +196,6 @@ enum InitSyncTracker{
193196
struct Peer {
194197
transport: Transport,
195198
outbound: bool,
196-
their_node_id: Option<PublicKey>,
197199
their_features: Option<InitFeatures>,
198200

199201
pending_outbound_buffer: OutboundQueue,
@@ -339,7 +341,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
339341
if !p.transport.is_connected() || p.their_features.is_none() {
340342
return None;
341343
}
342-
p.their_node_id
344+
Some(p.transport.get_their_node_id())
343345
}).collect()
344346
}
345347

@@ -372,7 +374,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
372374
if peers.peers.insert(descriptor, Peer {
373375
transport,
374376
outbound: true,
375-
their_node_id: Some(their_node_id.clone()),
376377
their_features: None,
377378

378379
pending_outbound_buffer: OutboundQueue::new(MSG_BUFF_SIZE),
@@ -400,7 +401,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
400401
if peers.peers.insert(descriptor, Peer {
401402
transport: Transport::new_inbound(&self.our_node_secret, &self.get_ephemeral_key()),
402403
outbound: false,
403-
their_node_id: None,
404404
their_features: None,
405405

406406
pending_outbound_buffer: OutboundQueue::new(MSG_BUFF_SIZE),
@@ -539,22 +539,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
539539

540540
// If the transport is newly connected, do the appropriate set up for the connection
541541
if peer.transport.is_connected() {
542-
let their_node_id = peer.transport.their_node_id.unwrap();
542+
let their_node_id = peer.transport.get_their_node_id();
543543

544544
match peers.node_id_to_descriptor.entry(their_node_id.clone()) {
545545
hash_map::Entry::Occupied(entry) => {
546546
if entry.get() != peer_descriptor {
547547
// Existing entry in map is from a different descriptor, this is a duplicate
548548
log_trace!(self.logger, "Got second connection with {}, closing", log_pubkey!(&their_node_id));
549-
peer.their_node_id = None;
550549
return Err(PeerHandleError { no_connection_possible: false });
551550
} else {
552551
// read_event for existing peer
553552
}
554553
},
555554
hash_map::Entry::Vacant(entry) => {
556555
log_trace!(self.logger, "Finished noise handshake for connection with {}", log_pubkey!(&their_node_id));
557-
peer.their_node_id = Some(their_node_id.clone());
558556

559557
if peer.outbound {
560558
let mut features = InitFeatures::known();
@@ -625,12 +623,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
625623

626624
/// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
627625
fn handle_message(&self, peers_needing_send: &mut HashSet<Descriptor>, peer: &mut Peer, peer_descriptor: Descriptor, message: wire::Message) -> Result<(), MessageHandlingError> {
628-
log_trace!(self.logger, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.their_node_id.unwrap()));
626+
log_trace!(self.logger, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.transport.get_their_node_id()));
629627

630628
// Need an Init as first message
631629
if let wire::Message::Init(_) = message {
632630
} else if peer.their_features.is_none() {
633-
log_trace!(self.logger, "Peer {} sent non-Init first message", log_pubkey!(peer.their_node_id.unwrap()));
631+
log_trace!(self.logger, "Peer {} sent non-Init first message", log_pubkey!(peer.transport.get_their_node_id()));
634632
return Err(PeerHandleError{ no_connection_possible: false }.into());
635633
}
636634

@@ -663,21 +661,21 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
663661
peers_needing_send.insert(peer_descriptor.clone());
664662
}
665663
if !msg.features.supports_static_remote_key() {
666-
log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting with no_connection_possible", log_pubkey!(peer.their_node_id.unwrap()));
664+
log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting with no_connection_possible", log_pubkey!(peer.transport.get_their_node_id()));
667665
return Err(PeerHandleError{ no_connection_possible: true }.into());
668666
}
669667

670668
if !peer.outbound {
671669
let mut features = InitFeatures::known();
672-
if !self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
670+
if !self.message_handler.route_handler.should_request_full_sync(&peer.transport.get_their_node_id()) {
673671
features.clear_initial_routing_sync();
674672
}
675673

676674
let resp = msgs::Init { features };
677675
self.enqueue_message(peers_needing_send, &mut peer.transport, &mut peer.pending_outbound_buffer, &peer_descriptor, &resp);
678676
}
679677

680-
self.message_handler.chan_handler.peer_connected(&peer.their_node_id.unwrap(), &msg);
678+
self.message_handler.chan_handler.peer_connected(&peer.transport.get_their_node_id(), &msg);
681679
peer.their_features = Some(msg.features);
682680
},
683681
wire::Message::Error(msg) => {
@@ -690,11 +688,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
690688
}
691689

692690
if data_is_printable {
693-
log_debug!(self.logger, "Got Err message from {}: {}", log_pubkey!(peer.their_node_id.unwrap()), msg.data);
691+
log_debug!(self.logger, "Got Err message from {}: {}", log_pubkey!(peer.transport.get_their_node_id()), msg.data);
694692
} else {
695-
log_debug!(self.logger, "Got Err message from {} with non-ASCII error message", log_pubkey!(peer.their_node_id.unwrap()));
693+
log_debug!(self.logger, "Got Err message from {} with non-ASCII error message", log_pubkey!(peer.transport.get_their_node_id()));
696694
}
697-
self.message_handler.chan_handler.handle_error(&peer.their_node_id.unwrap(), &msg);
695+
self.message_handler.chan_handler.handle_error(&peer.transport.get_their_node_id(), &msg);
698696
if msg.channel_id == [0; 32] {
699697
return Err(PeerHandleError{ no_connection_possible: true }.into());
700698
}
@@ -712,59 +710,59 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
712710

713711
// Channel messages:
714712
wire::Message::OpenChannel(msg) => {
715-
self.message_handler.chan_handler.handle_open_channel(&peer.their_node_id.unwrap(), peer.their_features.clone().unwrap(), &msg);
713+
self.message_handler.chan_handler.handle_open_channel(&peer.transport.get_their_node_id(), peer.their_features.clone().unwrap(), &msg);
716714
},
717715
wire::Message::AcceptChannel(msg) => {
718-
self.message_handler.chan_handler.handle_accept_channel(&peer.their_node_id.unwrap(), peer.their_features.clone().unwrap(), &msg);
716+
self.message_handler.chan_handler.handle_accept_channel(&peer.transport.get_their_node_id(), peer.their_features.clone().unwrap(), &msg);
719717
},
720718

721719
wire::Message::FundingCreated(msg) => {
722-
self.message_handler.chan_handler.handle_funding_created(&peer.their_node_id.unwrap(), &msg);
720+
self.message_handler.chan_handler.handle_funding_created(&peer.transport.get_their_node_id(), &msg);
723721
},
724722
wire::Message::FundingSigned(msg) => {
725-
self.message_handler.chan_handler.handle_funding_signed(&peer.their_node_id.unwrap(), &msg);
723+
self.message_handler.chan_handler.handle_funding_signed(&peer.transport.get_their_node_id(), &msg);
726724
},
727725
wire::Message::FundingLocked(msg) => {
728-
self.message_handler.chan_handler.handle_funding_locked(&peer.their_node_id.unwrap(), &msg);
726+
self.message_handler.chan_handler.handle_funding_locked(&peer.transport.get_their_node_id(), &msg);
729727
},
730728

731729
wire::Message::Shutdown(msg) => {
732-
self.message_handler.chan_handler.handle_shutdown(&peer.their_node_id.unwrap(), &msg);
730+
self.message_handler.chan_handler.handle_shutdown(&peer.transport.get_their_node_id(), &msg);
733731
},
734732
wire::Message::ClosingSigned(msg) => {
735-
self.message_handler.chan_handler.handle_closing_signed(&peer.their_node_id.unwrap(), &msg);
733+
self.message_handler.chan_handler.handle_closing_signed(&peer.transport.get_their_node_id(), &msg);
736734
},
737735

738736
// Commitment messages:
739737
wire::Message::UpdateAddHTLC(msg) => {
740-
self.message_handler.chan_handler.handle_update_add_htlc(&peer.their_node_id.unwrap(), &msg);
738+
self.message_handler.chan_handler.handle_update_add_htlc(&peer.transport.get_their_node_id(), &msg);
741739
},
742740
wire::Message::UpdateFulfillHTLC(msg) => {
743-
self.message_handler.chan_handler.handle_update_fulfill_htlc(&peer.their_node_id.unwrap(), &msg);
741+
self.message_handler.chan_handler.handle_update_fulfill_htlc(&peer.transport.get_their_node_id(), &msg);
744742
},
745743
wire::Message::UpdateFailHTLC(msg) => {
746-
self.message_handler.chan_handler.handle_update_fail_htlc(&peer.their_node_id.unwrap(), &msg);
744+
self.message_handler.chan_handler.handle_update_fail_htlc(&peer.transport.get_their_node_id(), &msg);
747745
},
748746
wire::Message::UpdateFailMalformedHTLC(msg) => {
749-
self.message_handler.chan_handler.handle_update_fail_malformed_htlc(&peer.their_node_id.unwrap(), &msg);
747+
self.message_handler.chan_handler.handle_update_fail_malformed_htlc(&peer.transport.get_their_node_id(), &msg);
750748
},
751749

752750
wire::Message::CommitmentSigned(msg) => {
753-
self.message_handler.chan_handler.handle_commitment_signed(&peer.their_node_id.unwrap(), &msg);
751+
self.message_handler.chan_handler.handle_commitment_signed(&peer.transport.get_their_node_id(), &msg);
754752
},
755753
wire::Message::RevokeAndACK(msg) => {
756-
self.message_handler.chan_handler.handle_revoke_and_ack(&peer.their_node_id.unwrap(), &msg);
754+
self.message_handler.chan_handler.handle_revoke_and_ack(&peer.transport.get_their_node_id(), &msg);
757755
},
758756
wire::Message::UpdateFee(msg) => {
759-
self.message_handler.chan_handler.handle_update_fee(&peer.their_node_id.unwrap(), &msg);
757+
self.message_handler.chan_handler.handle_update_fee(&peer.transport.get_their_node_id(), &msg);
760758
},
761759
wire::Message::ChannelReestablish(msg) => {
762-
self.message_handler.chan_handler.handle_channel_reestablish(&peer.their_node_id.unwrap(), &msg);
760+
self.message_handler.chan_handler.handle_channel_reestablish(&peer.transport.get_their_node_id(), &msg);
763761
},
764762

765763
// Routing messages:
766764
wire::Message::AnnouncementSignatures(msg) => {
767-
self.message_handler.chan_handler.handle_announcement_signatures(&peer.their_node_id.unwrap(), &msg);
765+
self.message_handler.chan_handler.handle_announcement_signatures(&peer.transport.get_their_node_id(), &msg);
768766
},
769767
wire::Message::ChannelAnnouncement(msg) => {
770768
let should_forward = match self.message_handler.route_handler.handle_channel_announcement(&msg) {
@@ -1009,13 +1007,10 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
10091007
!peer.should_forward_channel_announcement(msg.contents.short_channel_id) {
10101008
continue
10111009
}
1012-
match peer.their_node_id {
1013-
None => continue,
1014-
Some(their_node_id) => {
1015-
if their_node_id == msg.contents.node_id_1 || their_node_id == msg.contents.node_id_2 {
1016-
continue
1017-
}
1018-
}
1010+
1011+
let their_node_id = peer.transport.get_their_node_id();
1012+
if their_node_id == msg.contents.node_id_1 || their_node_id == msg.contents.node_id_2 {
1013+
continue
10191014
}
10201015
if peer.transport.is_connected() {
10211016
peer.transport.enqueue_message(msg, &mut peer.pending_outbound_buffer, &*self.logger);
@@ -1128,12 +1123,17 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
11281123
match peer_option {
11291124
None => panic!("Descriptor for disconnect_event is not already known to PeerManager"),
11301125
Some(peer) => {
1131-
match peer.their_node_id {
1132-
Some(node_id) => {
1126+
if peer.transport.is_connected() {
1127+
let node_id = peer.transport.get_their_node_id();
1128+
1129+
if peers.node_id_to_descriptor.get(&node_id).unwrap() == descriptor {
11331130
peers.node_id_to_descriptor.remove(&node_id);
11341131
self.message_handler.chan_handler.peer_disconnected(&node_id, no_connection_possible);
1135-
},
1136-
None => {}
1132+
} else {
1133+
// This must have been generated from a duplicate connection error
1134+
}
1135+
} else {
1136+
// Unconnected nodes never make it into node_id_to_descriptor
11371137
}
11381138
}
11391139
};
@@ -1156,18 +1156,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
11561156
if peer.awaiting_pong {
11571157
peers_needing_send.remove(descriptor);
11581158
descriptors_needing_disconnect.push(descriptor.clone());
1159-
match peer.their_node_id {
1160-
Some(node_id) => {
1161-
log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", node_id);
1162-
node_id_to_descriptor.remove(&node_id);
1163-
self.message_handler.chan_handler.peer_disconnected(&node_id, false);
1164-
}
1165-
None => {
1166-
// This can't actually happen as we should have hit
1167-
// is_connected() previously on this same peer.
1168-
unreachable!();
1169-
},
1170-
}
1159+
let their_node_id = peer.transport.get_their_node_id();
1160+
log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", their_node_id);
1161+
node_id_to_descriptor.remove(&their_node_id);
1162+
self.message_handler.chan_handler.peer_disconnected(&their_node_id, false);
1163+
11711164
return false;
11721165
}
11731166

lightning/src/ln/peers/transport.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub trait IPeerHandshake {
3131
pub(super) struct Transport<PeerHandshakeImpl: IPeerHandshake=PeerHandshake> {
3232
pub(super) conduit: Option<Conduit>,
3333
handshake: PeerHandshakeImpl,
34-
pub(super) their_node_id: Option<PublicKey>,
34+
their_node_id: Option<PublicKey>,
3535
}
3636

3737
impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeImpl> {
@@ -160,6 +160,11 @@ impl<PeerHandshakeImpl: IPeerHandshake> ITransport for Transport<PeerHandshakeIm
160160
}
161161
}
162162
}
163+
164+
fn get_their_node_id(&self) -> PublicKey {
165+
assert!(self.is_connected(), "Retrieving the remote node_id is only supported after transport is connected");
166+
self.their_node_id.unwrap()
167+
}
163168
}
164169

165170
#[cfg(test)]
@@ -254,6 +259,42 @@ mod tests {
254259
assert!(transport.is_connected());
255260
}
256261

262+
// Test get_their_node_id() in unconnected and connected scenarios
263+
#[test]
264+
#[should_panic(expected = "Retrieving the remote node_id is only supported after transport is connected")]
265+
fn inbound_unconnected_get_their_node_id_panics() {
266+
let transport = create_inbound_for_test::<PeerHandshakeTestStubFail>();
267+
268+
let _should_panic = transport.get_their_node_id();
269+
}
270+
271+
#[test]
272+
#[should_panic(expected = "Retrieving the remote node_id is only supported after transport is connected")]
273+
fn outbound_unconnected_get_their_node_id_panics() {
274+
let mut transport = create_outbound_for_test::<PeerHandshakeTestStubFail>();
275+
transport.set_up_outbound();
276+
277+
let _should_panic = transport.get_their_node_id();
278+
}
279+
280+
#[test]
281+
fn inbound_unconnected_get_their_node_id() {
282+
let mut transport = create_inbound_for_test::<PeerHandshakeTestStubComplete>();
283+
let mut spy = Vec::new();
284+
285+
transport.process_input(&[], &mut spy).unwrap();
286+
let _no_panic = transport.get_their_node_id();
287+
}
288+
289+
#[test]
290+
fn outbound_unconnected_get_their_node_id() {
291+
let mut transport = create_inbound_for_test::<PeerHandshakeTestStubComplete>();
292+
let mut spy = Vec::new();
293+
294+
transport.process_input(&[], &mut spy).unwrap();
295+
let _no_panic = transport.get_their_node_id();
296+
}
297+
257298
// Test that when a handshake completes is_connected() is correct
258299
#[test]
259300
fn outbound_handshake_complete_ready_for_encryption() {
@@ -262,6 +303,7 @@ mod tests {
262303

263304
transport.process_input(&[], &mut spy).unwrap();
264305
assert!(transport.is_connected());
306+
let _no_panic = transport.get_their_node_id();
265307
}
266308

267309
#[test]

0 commit comments

Comments
 (0)