@@ -20,13 +20,15 @@ use crate::util::test_utils;
20
20
use bitcoin:: network:: constants:: Network ;
21
21
use bitcoin:: secp256k1:: { PublicKey , Secp256k1 } ;
22
22
23
+ use core:: sync:: atomic:: { AtomicU16 , Ordering } ;
23
24
use crate :: io;
24
25
use crate :: io_extras:: read_to_end;
25
26
use crate :: sync:: Arc ;
26
27
27
28
struct MessengerNode {
28
29
keys_manager : Arc < test_utils:: TestKeysInterface > ,
29
30
messenger : OnionMessenger < Arc < test_utils:: TestKeysInterface > , Arc < test_utils:: TestKeysInterface > , Arc < test_utils:: TestLogger > , Arc < TestCustomMessageHandler > > ,
31
+ custom_message_handler : Arc < TestCustomMessageHandler > ,
30
32
logger : Arc < test_utils:: TestLogger > ,
31
33
}
32
34
@@ -54,11 +56,32 @@ impl Writeable for TestCustomMessage {
54
56
}
55
57
}
56
58
57
- struct TestCustomMessageHandler { }
59
+ struct TestCustomMessageHandler {
60
+ num_messages_expected : AtomicU16 ,
61
+ }
62
+
63
+ impl TestCustomMessageHandler {
64
+ fn new ( ) -> Self {
65
+ Self { num_messages_expected : AtomicU16 :: new ( 0 ) }
66
+ }
67
+ }
68
+
69
+ impl Drop for TestCustomMessageHandler {
70
+ fn drop ( & mut self ) {
71
+ #[ cfg( feature = "std" ) ] {
72
+ if std:: thread:: panicking ( ) {
73
+ return ;
74
+ }
75
+ }
76
+ assert_eq ! ( self . num_messages_expected. load( Ordering :: SeqCst ) , 0 ) ;
77
+ }
78
+ }
58
79
59
80
impl CustomOnionMessageHandler for TestCustomMessageHandler {
60
81
type CustomMessage = TestCustomMessage ;
61
- fn handle_custom_message ( & self , _msg : Self :: CustomMessage ) { }
82
+ fn handle_custom_message ( & self , _msg : Self :: CustomMessage ) {
83
+ self . num_messages_expected . fetch_sub ( 1 , Ordering :: SeqCst ) ;
84
+ }
62
85
fn read_custom_message < R : io:: Read > ( & self , message_type : u64 , buffer : & mut R ) -> Result < Option < Self :: CustomMessage > , DecodeError > where Self : Sized {
63
86
if message_type == CUSTOM_MESSAGE_TYPE {
64
87
let buf = read_to_end ( buffer) ?;
@@ -75,9 +98,11 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
75
98
let logger = Arc :: new ( test_utils:: TestLogger :: with_id ( format ! ( "node {}" , i) ) ) ;
76
99
let seed = [ i as u8 ; 32 ] ;
77
100
let keys_manager = Arc :: new ( test_utils:: TestKeysInterface :: new ( & seed, Network :: Testnet ) ) ;
101
+ let custom_message_handler = Arc :: new ( TestCustomMessageHandler :: new ( ) ) ;
78
102
nodes. push ( MessengerNode {
79
103
keys_manager : keys_manager. clone ( ) ,
80
- messenger : OnionMessenger :: new ( keys_manager. clone ( ) , keys_manager. clone ( ) , logger. clone ( ) , Arc :: new ( TestCustomMessageHandler { } ) ) ,
104
+ messenger : OnionMessenger :: new ( keys_manager. clone ( ) , keys_manager. clone ( ) , logger. clone ( ) , custom_message_handler. clone ( ) ) ,
105
+ custom_message_handler,
81
106
logger,
82
107
} ) ;
83
108
}
@@ -92,22 +117,17 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
92
117
nodes
93
118
}
94
119
95
- fn pass_along_path ( path : & Vec < MessengerNode > , expected_path_id : Option < [ u8 ; 32 ] > ) {
120
+ fn pass_along_path ( path : & Vec < MessengerNode > ) {
121
+ path[ path. len ( ) - 1 ] . custom_message_handler . num_messages_expected . fetch_add ( 1 , Ordering :: SeqCst ) ;
96
122
let mut prev_node = & path[ 0 ] ;
97
- let num_nodes = path. len ( ) ;
98
- for ( idx, node) in path. into_iter ( ) . skip ( 1 ) . enumerate ( ) {
123
+ for node in path. into_iter ( ) . skip ( 1 ) {
99
124
let events = prev_node. messenger . release_pending_msgs ( ) ;
100
125
let onion_msg = {
101
126
let msgs = events. get ( & node. get_node_pk ( ) ) . unwrap ( ) ;
102
127
assert_eq ! ( msgs. len( ) , 1 ) ;
103
128
msgs[ 0 ] . clone ( )
104
129
} ;
105
130
node. messenger . handle_onion_message ( & prev_node. get_node_pk ( ) , & onion_msg) ;
106
- if idx == num_nodes - 1 {
107
- node. logger . assert_log_contains (
108
- "lightning::onion_message::messenger" ,
109
- & format ! ( "Received an onion message with path_id: {:02x?}" , expected_path_id) , 1 ) ;
110
- }
111
131
prev_node = node;
112
132
}
113
133
}
@@ -118,7 +138,7 @@ fn one_hop() {
118
138
let test_msg = OnionMessageContents :: Custom ( TestCustomMessage { } ) ;
119
139
120
140
nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: Node ( nodes[ 1 ] . get_node_pk ( ) ) , test_msg, None ) . unwrap ( ) ;
121
- pass_along_path ( & nodes, None ) ;
141
+ pass_along_path ( & nodes) ;
122
142
}
123
143
124
144
#[ test]
@@ -127,7 +147,7 @@ fn two_unblinded_hops() {
127
147
let test_msg = OnionMessageContents :: Custom ( TestCustomMessage { } ) ;
128
148
129
149
nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) ] , Destination :: Node ( nodes[ 2 ] . get_node_pk ( ) ) , test_msg, None ) . unwrap ( ) ;
130
- pass_along_path ( & nodes, None ) ;
150
+ pass_along_path ( & nodes) ;
131
151
}
132
152
133
153
#[ test]
@@ -139,7 +159,7 @@ fn two_unblinded_two_blinded() {
139
159
let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 3 ] . get_node_pk ( ) , nodes[ 4 ] . get_node_pk ( ) ] , & * nodes[ 4 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
140
160
141
161
nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , Destination :: BlindedPath ( blinded_path) , test_msg, None ) . unwrap ( ) ;
142
- pass_along_path ( & nodes, None ) ;
162
+ pass_along_path ( & nodes) ;
143
163
}
144
164
145
165
#[ test]
@@ -151,7 +171,7 @@ fn three_blinded_hops() {
151
171
let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) , nodes[ 3 ] . get_node_pk ( ) ] , & * nodes[ 3 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
152
172
153
173
nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , test_msg, None ) . unwrap ( ) ;
154
- pass_along_path ( & nodes, None ) ;
174
+ pass_along_path ( & nodes) ;
155
175
}
156
176
157
177
#[ test]
@@ -177,13 +197,13 @@ fn we_are_intro_node() {
177
197
let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , & * nodes[ 2 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
178
198
179
199
nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg. clone ( ) ) , None ) . unwrap ( ) ;
180
- pass_along_path ( & nodes, None ) ;
200
+ pass_along_path ( & nodes) ;
181
201
182
202
// Try with a two-hop blinded path where we are the introduction node.
183
203
let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) ] , & * nodes[ 1 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
184
204
nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg) , None ) . unwrap ( ) ;
185
205
nodes. remove ( 2 ) ;
186
- pass_along_path ( & nodes, None ) ;
206
+ pass_along_path ( & nodes) ;
187
207
}
188
208
189
209
#[ test]
@@ -216,7 +236,7 @@ fn reply_path() {
216
236
// Destination::Node
217
237
let reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 0 ] . get_node_pk ( ) ] , & * nodes[ 0 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
218
238
nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , Destination :: Node ( nodes[ 3 ] . get_node_pk ( ) ) , OnionMessageContents :: Custom ( test_msg. clone ( ) ) , Some ( reply_path) ) . unwrap ( ) ;
219
- pass_along_path ( & nodes, None ) ;
239
+ pass_along_path ( & nodes) ;
220
240
// Make sure the last node successfully decoded the reply path.
221
241
nodes[ 3 ] . logger . assert_log_contains (
222
242
"lightning::onion_message::messenger" ,
@@ -227,7 +247,7 @@ fn reply_path() {
227
247
let reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 0 ] . get_node_pk ( ) ] , & * nodes[ 0 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
228
248
229
249
nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg) , Some ( reply_path) ) . unwrap ( ) ;
230
- pass_along_path ( & nodes, None ) ;
250
+ pass_along_path ( & nodes) ;
231
251
nodes[ 3 ] . logger . assert_log_contains (
232
252
"lightning::onion_message::messenger" ,
233
253
& format ! ( "Received an onion message with path_id None and a reply_path" ) , 2 ) ;
@@ -264,3 +284,20 @@ fn peer_buffer_full() {
264
284
let err = nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: Node ( nodes[ 1 ] . get_node_pk ( ) ) , OnionMessageContents :: Custom ( test_msg) , None ) . unwrap_err ( ) ;
265
285
assert_eq ! ( err, SendError :: BufferFull ) ;
266
286
}
287
+
288
+ #[ test]
289
+ fn many_hops ( ) {
290
+ // Check we can send over a route with many hops. This will exercise our logic for onion messages
291
+ // of size [`crate::onion_message::packet::BIG_PACKET_HOP_DATA_LEN`].
292
+ let num_nodes: usize = 25 ;
293
+ let nodes = create_nodes ( num_nodes as u8 ) ;
294
+ let test_msg = OnionMessageContents :: Custom ( TestCustomMessage { } ) ;
295
+
296
+ let mut intermediates = vec ! [ ] ;
297
+ for i in 1 ..( num_nodes-1 ) {
298
+ intermediates. push ( nodes[ i] . get_node_pk ( ) ) ;
299
+ }
300
+
301
+ nodes[ 0 ] . messenger . send_onion_message ( & intermediates, Destination :: Node ( nodes[ num_nodes-1 ] . get_node_pk ( ) ) , test_msg, None ) . unwrap ( ) ;
302
+ pass_along_path ( & nodes) ;
303
+ }
0 commit comments