@@ -169,6 +169,18 @@ impl PeerChannelEncryptor {
169
169
res. extend_from_slice ( & tag) ;
170
170
}
171
171
172
+ fn decrypt_in_place_with_ad ( inout : & mut [ u8 ] , n : u64 , key : & [ u8 ; 32 ] , h : & [ u8 ] ) -> Result < ( ) , LightningError > {
173
+ let mut nonce = [ 0 ; 12 ] ;
174
+ nonce[ 4 ..] . copy_from_slice ( & n. to_le_bytes ( ) [ ..] ) ;
175
+
176
+ let mut chacha = ChaCha20Poly1305RFC :: new ( key, & nonce, h) ;
177
+ let ( inout, tag) = inout. split_at_mut ( inout. len ( ) - 16 ) ;
178
+ if chacha. decrypt_in_place ( inout, tag) . is_err ( ) {
179
+ return Err ( LightningError { err : "Bad MAC" . to_owned ( ) , action : msgs:: ErrorAction :: DisconnectPeer { msg : None } } ) ;
180
+ }
181
+ Ok ( ( ) )
182
+ }
183
+
172
184
#[ inline]
173
185
fn decrypt_with_ad ( res : & mut [ u8 ] , n : u64 , key : & [ u8 ; 32 ] , h : & [ u8 ] , cyphertext : & [ u8 ] ) -> Result < ( ) , LightningError > {
174
186
let mut nonce = [ 0 ; 12 ] ;
@@ -505,21 +517,20 @@ impl PeerChannelEncryptor {
505
517
}
506
518
}
507
519
508
- /// Decrypts the given message.
520
+ /// Decrypts the given message up to msg.len() - 16. Bytes after msg.len() - 16 will be left
521
+ /// undefined (as they contain the Poly1305 tag bytes).
522
+ ///
509
523
/// panics if msg.len() > 65535 + 16
510
- pub fn decrypt_message ( & mut self , msg : & [ u8 ] ) -> Result < Vec < u8 > , LightningError > {
524
+ pub fn decrypt_message ( & mut self , msg : & mut [ u8 ] ) -> Result < ( ) , LightningError > {
511
525
if msg. len ( ) > LN_MAX_MSG_LEN + 16 {
512
526
panic ! ( "Attempted to decrypt message longer than 65535 + 16 bytes!" ) ;
513
527
}
514
528
515
529
match self . noise_state {
516
530
NoiseState :: Finished { sk : _, sn : _, sck : _, ref rk, ref mut rn, rck : _ } => {
517
- let mut res = Vec :: with_capacity ( msg. len ( ) - 16 ) ;
518
- res. resize ( msg. len ( ) - 16 , 0 ) ;
519
- Self :: decrypt_with_ad ( & mut res[ ..] , * rn, rk, & [ 0 ; 0 ] , msg) ?;
531
+ Self :: decrypt_in_place_with_ad ( & mut msg[ ..] , * rn, rk, & [ 0 ; 0 ] ) ?;
520
532
* rn += 1 ;
521
-
522
- Ok ( res)
533
+ Ok ( ( ) )
523
534
} ,
524
535
_ => panic ! ( "Tried to decrypt a message prior to noise handshake completion" ) ,
525
536
}
@@ -764,12 +775,11 @@ mod tests {
764
775
765
776
for i in 0 ..1005 {
766
777
let msg = [ 0x68 , 0x65 , 0x6c , 0x6c , 0x6f ] ;
767
- let res = outbound_peer. encrypt_buffer ( & msg) ;
778
+ let mut res = outbound_peer. encrypt_buffer ( & msg) ;
768
779
assert_eq ! ( res. len( ) , 5 + 2 * 16 + 2 ) ;
769
780
770
781
let len_header = res[ 0 ..2 +16 ] . to_vec ( ) ;
771
782
assert_eq ! ( inbound_peer. decrypt_length_header( & len_header[ ..] ) . unwrap( ) as usize , msg. len( ) ) ;
772
- assert_eq ! ( inbound_peer. decrypt_message( & res[ 2 +16 ..] ) . unwrap( ) [ ..] , msg[ ..] ) ;
773
783
774
784
if i == 0 {
775
785
assert_eq ! ( res, hex:: decode( "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d2f214cf9ea1d95" ) . unwrap( ) ) ;
@@ -784,6 +794,9 @@ mod tests {
784
794
} else if i == 1001 {
785
795
assert_eq ! ( res, hex:: decode( "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338b1a16cf4ef2d36" ) . unwrap( ) ) ;
786
796
}
797
+
798
+ inbound_peer. decrypt_message ( & mut res[ 2 +16 ..] ) . unwrap ( ) ;
799
+ assert_eq ! ( res[ 2 + 16 ..res. len( ) - 16 ] , msg[ ..] ) ;
787
800
}
788
801
}
789
802
@@ -807,7 +820,7 @@ mod tests {
807
820
let mut inbound_peer = get_inbound_peer_for_test_vectors ( ) ;
808
821
809
822
// MSG should not exceed LN_MAX_MSG_LEN + 16
810
- let msg = [ 4u8 ; LN_MAX_MSG_LEN + 17 ] ;
811
- inbound_peer. decrypt_message ( & msg) . unwrap ( ) ;
823
+ let mut msg = [ 4u8 ; LN_MAX_MSG_LEN + 17 ] ;
824
+ inbound_peer. decrypt_message ( & mut msg) . unwrap ( ) ;
812
825
}
813
826
}
0 commit comments