@@ -38,13 +38,29 @@ pub(super) struct Decryptor {
38
38
39
39
pending_message_length : Option < usize > ,
40
40
read_buffer : Option < Vec < u8 > > ,
41
+ poisoned : bool , // signal an error has occurred so None is returned on iteration after failure
41
42
}
42
43
43
44
impl Iterator for Decryptor {
44
- type Item = Vec < u8 > ;
45
+ type Item = Result < Option < Vec < u8 > > , String > ;
45
46
46
47
fn next ( & mut self ) -> Option < Self :: Item > {
47
- self . decrypt_single_message ( None )
48
+ if self . poisoned {
49
+ return None ;
50
+ }
51
+
52
+ match self . decrypt_single_message ( None ) {
53
+ Ok ( Some ( result) ) => {
54
+ Some ( Ok ( Some ( result) ) )
55
+ } ,
56
+ Ok ( None ) => {
57
+ None
58
+ }
59
+ Err ( e) => {
60
+ self . poisoned = true ;
61
+ Some ( Err ( e) )
62
+ }
63
+ }
48
64
}
49
65
}
50
66
@@ -62,7 +78,8 @@ impl Conduit {
62
78
receiving_chaining_key : chaining_key,
63
79
receiving_nonce : 0 ,
64
80
read_buffer : None ,
65
- pending_message_length : None
81
+ pending_message_length : None ,
82
+ poisoned : false
66
83
}
67
84
}
68
85
}
@@ -81,8 +98,8 @@ impl Conduit {
81
98
/// If a message pending in the buffer still hasn't been decrypted, that message will be
82
99
/// returned in lieu of anything new, even if new data is provided.
83
100
#[ cfg( any( test, feature = "fuzztarget" ) ) ]
84
- pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Option < Vec < u8 > > {
85
- self . decryptor . decrypt_single_message ( new_data)
101
+ pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Result < Option < Vec < u8 > > , String > {
102
+ Ok ( self . decryptor . decrypt_single_message ( new_data) ? )
86
103
}
87
104
88
105
fn increment_nonce ( nonce : & mut u32 , chaining_key : & mut SymmetricKey , key : & mut SymmetricKey ) {
@@ -135,7 +152,7 @@ impl Decryptor {
135
152
/// only the first message will be returned, and the rest stored in the internal buffer.
136
153
/// If a message pending in the buffer still hasn't been decrypted, that message will be
137
154
/// returned in lieu of anything new, even if new data is provided.
138
- pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Option < Vec < u8 > > {
155
+ pub fn decrypt_single_message ( & mut self , new_data : Option < & [ u8 ] > ) -> Result < Option < Vec < u8 > > , String > {
139
156
let mut read_buffer = if let Some ( buffer) = self . read_buffer . take ( ) {
140
157
buffer
141
158
} else {
@@ -150,25 +167,25 @@ impl Decryptor {
150
167
panic ! ( "Attempted to decrypt message longer than 65535 + 16 bytes!" ) ;
151
168
}
152
169
153
- let ( current_message, offset) = self . decrypt ( & read_buffer[ ..] ) ;
170
+ let ( current_message, offset) = self . decrypt ( & read_buffer[ ..] ) ? ;
154
171
read_buffer. drain ( ..offset) ; // drain the read buffer
155
172
self . read_buffer = Some ( read_buffer) ; // assign the new value to the built-in buffer
156
- current_message
173
+ Ok ( current_message)
157
174
}
158
175
159
- fn decrypt ( & mut self , buffer : & [ u8 ] ) -> ( Option < Vec < u8 > > , usize ) {
176
+ fn decrypt ( & mut self , buffer : & [ u8 ] ) -> Result < ( Option < Vec < u8 > > , usize ) , String > {
160
177
let message_length = if let Some ( length) = self . pending_message_length {
161
178
// we have already decrypted the header
162
179
length
163
180
} else {
164
181
if buffer. len ( ) < TAGGED_MESSAGE_LENGTH_HEADER_SIZE {
165
182
// A message must be at least 18 bytes (2 for encrypted length, 16 for the tag)
166
- return ( None , 0 ) ;
183
+ return Ok ( ( None , 0 ) ) ;
167
184
}
168
185
169
186
let encrypted_length = & buffer[ 0 ..TAGGED_MESSAGE_LENGTH_HEADER_SIZE ] ;
170
187
let mut length_bytes = [ 0u8 ; MESSAGE_LENGTH_HEADER_SIZE ] ;
171
- length_bytes. copy_from_slice ( & chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_length) . unwrap ( ) ) ;
188
+ length_bytes. copy_from_slice ( & chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_length) ? ) ;
172
189
173
190
self . increment_nonce ( ) ;
174
191
@@ -180,18 +197,18 @@ impl Decryptor {
180
197
181
198
if buffer. len ( ) < message_end_index {
182
199
self . pending_message_length = Some ( message_length) ;
183
- return ( None , 0 ) ;
200
+ return Ok ( ( None , 0 ) ) ;
184
201
}
185
202
186
203
self . pending_message_length = None ;
187
204
188
205
let encrypted_message = & buffer[ TAGGED_MESSAGE_LENGTH_HEADER_SIZE ..message_end_index] ;
189
206
190
- let message = chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_message) . unwrap ( ) ;
207
+ let message = chacha:: decrypt ( & self . receiving_key , self . receiving_nonce as u64 , & [ 0 ; 0 ] , encrypted_message) ? ;
191
208
192
209
self . increment_nonce ( ) ;
193
210
194
- ( Some ( message) , message_end_index)
211
+ Ok ( ( Some ( message) , message_end_index) )
195
212
}
196
213
197
214
fn increment_nonce ( & mut self ) {
@@ -243,7 +260,7 @@ mod tests {
243
260
let encrypted_message = connected_peer. encrypt ( & message) ;
244
261
assert_eq ! ( encrypted_message. len( ) , 2 + 16 + 16 ) ;
245
262
246
- let decrypted_message = remote_peer. decrypt_single_message ( Some ( & encrypted_message) ) . unwrap ( ) ;
263
+ let decrypted_message = remote_peer. decrypt_single_message ( Some ( & encrypted_message) ) . unwrap ( ) . unwrap ( ) ;
247
264
assert_eq ! ( decrypted_message, Vec :: <u8 >:: new( ) ) ;
248
265
}
249
266
@@ -296,17 +313,123 @@ mod tests {
296
313
let mut current_encrypted_message = encrypted_messages. remove ( 0 ) ;
297
314
let next_encrypted_message = encrypted_messages. remove ( 0 ) ;
298
315
current_encrypted_message. extend_from_slice ( & next_encrypted_message) ;
299
- let decrypted_message = remote_peer. decrypt_single_message ( Some ( & current_encrypted_message) ) . unwrap ( ) ;
316
+ let decrypted_message = remote_peer. decrypt_single_message ( Some ( & current_encrypted_message) ) . unwrap ( ) . unwrap ( ) ;
300
317
assert_eq ! ( decrypted_message, message) ;
301
318
}
302
319
303
320
for _ in 0 ..501 {
304
321
// decrypt messages directly from buffer without adding to it
305
- let decrypted_message = remote_peer. decrypt_single_message ( None ) . unwrap ( ) ;
322
+ let decrypted_message = remote_peer. decrypt_single_message ( None ) . unwrap ( ) . unwrap ( ) ;
306
323
assert_eq ! ( decrypted_message, message) ;
307
324
}
308
325
}
309
326
327
+ // Decryption errors should result in Err
328
+ #[ test]
329
+ fn decryption_failure_errors ( ) {
330
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
331
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
332
+
333
+ connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
334
+ assert_eq ! ( connected_peer. decrypt_single_message( Some ( & encrypted) ) , Err ( "invalid hmac" . to_string( ) ) ) ;
335
+ }
336
+
337
+ // Test next()::None
338
+ #[ test]
339
+ fn decryptor_iterator_empty ( ) {
340
+ let ( mut connected_peer, _) = setup_peers ( ) ;
341
+
342
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
343
+ }
344
+
345
+ // Test next() -> next()::None
346
+ #[ test]
347
+ fn decryptor_iterator_one_item_valid ( ) {
348
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
349
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
350
+ connected_peer. read ( & encrypted) ;
351
+
352
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Ok ( Some ( vec![ 1 ] ) ) ) ) ;
353
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
354
+ }
355
+
356
+ // Test next()::err -> next()::None
357
+ #[ test]
358
+ fn decryptor_iterator_error ( ) {
359
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
360
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
361
+ connected_peer. read ( & encrypted) ;
362
+
363
+ connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
364
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Err ( "invalid hmac" . to_string( ) ) ) ) ;
365
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
366
+ }
367
+
368
+ // Test next()::Some -> next()::err -> next()::None
369
+ #[ test]
370
+ fn decryptor_iterator_error_after_success ( ) {
371
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
372
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
373
+ connected_peer. read ( & encrypted) ;
374
+ let encrypted = remote_peer. encrypt ( & [ 2 ] ) ;
375
+ connected_peer. read ( & encrypted) ;
376
+
377
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Ok ( Some ( vec![ 1 ] ) ) ) ) ;
378
+ connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
379
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Err ( "invalid hmac" . to_string( ) ) ) ) ;
380
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
381
+ }
382
+
383
+ // Test that next()::Some -> next()::err -> next()::None
384
+ // Error should poison decryptor
385
+ #[ test]
386
+ fn decryptor_iterator_next_after_error_returns_none ( ) {
387
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
388
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
389
+ connected_peer. read ( & encrypted) ;
390
+ let encrypted = remote_peer. encrypt ( & [ 2 ] ) ;
391
+ connected_peer. read ( & encrypted) ;
392
+ let encrypted = remote_peer. encrypt ( & [ 3 ] ) ;
393
+ connected_peer. read ( & encrypted) ;
394
+
395
+ // Get one valid value
396
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Ok ( Some ( vec![ 1 ] ) ) ) ) ;
397
+ let valid_receiving_key = connected_peer. decryptor . receiving_key ;
398
+
399
+ // Corrupt the receiving key and ensure we get a failure
400
+ connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
401
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Err ( "invalid hmac" . to_string( ) ) ) ) ;
402
+
403
+ // Restore the receiving key, do a read and ensure None is returned (poisoned)
404
+ connected_peer. decryptor . receiving_key = valid_receiving_key;
405
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
406
+ }
407
+
408
+ // Test next()::Some -> next()::err -> read() -> next()::None
409
+ // Error should poison decryptor even after future reads
410
+ #[ test]
411
+ fn decryptor_iterator_read_next_after_error_returns_none ( ) {
412
+ let ( mut connected_peer, mut remote_peer) = setup_peers ( ) ;
413
+ let encrypted = remote_peer. encrypt ( & [ 1 ] ) ;
414
+ connected_peer. read ( & encrypted) ;
415
+ let encrypted = remote_peer. encrypt ( & [ 2 ] ) ;
416
+ connected_peer. read ( & encrypted) ;
417
+
418
+ // Get one valid value
419
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Ok ( Some ( vec![ 1 ] ) ) ) ) ;
420
+ let valid_receiving_key = connected_peer. decryptor . receiving_key ;
421
+
422
+ // Corrupt the receiving key and ensure we get a failure
423
+ connected_peer. decryptor . receiving_key = [ 0 ; 32 ] ;
424
+ assert_eq ! ( connected_peer. decryptor. next( ) , Some ( Err ( "invalid hmac" . to_string( ) ) ) ) ;
425
+
426
+ // Restore the receiving key, do a read and ensure None is returned (poisoned)
427
+ let encrypted = remote_peer. encrypt ( & [ 3 ] ) ;
428
+ connected_peer. read ( & encrypted) ;
429
+ connected_peer. decryptor . receiving_key = valid_receiving_key;
430
+ assert_eq ! ( connected_peer. decryptor. next( ) , None ) ;
431
+ }
432
+
310
433
#[ test]
311
434
fn max_msg_len_limit_value ( ) {
312
435
assert_eq ! ( LN_MAX_MSG_LEN , 65535 ) ;
@@ -328,6 +451,6 @@ mod tests {
328
451
329
452
// MSG should not exceed LN_MAX_MSG_LEN + 16
330
453
let msg = [ 4u8 ; LN_MAX_MSG_LEN + 17 ] ;
331
- connected_peer. decrypt_single_message ( Some ( & msg) ) ;
454
+ connected_peer. decrypt_single_message ( Some ( & msg) ) . unwrap ( ) ;
332
455
}
333
456
}
0 commit comments