Skip to content

Commit a176777

Browse files
committed
Add macros for building TLV (de)serializers.
There's quite a bit of machinery included here, but it neatly avoids any dynamic allocation during TLV deserialization, and the calling side looks nice and simple. There's a few new state-tracking read/write streams, but they should be pretty cheap (just a few increments/decrements per read/write. The macro-generated code is pretty nice, though has some redundant if statements (I haven't checked if they get optimized out yet, but I can't imagine they don't).
1 parent 3c9538f commit a176777

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

lightning/src/util/ser.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::io::{Read, Write};
66
use std::collections::HashMap;
77
use std::hash::Hash;
88
use std::sync::Mutex;
9+
use std::cmp;
910

1011
use secp256k1::Signature;
1112
use secp256k1::key::{PublicKey, SecretKey};
@@ -67,6 +68,55 @@ impl Writer for VecWriter {
6768
}
6869
}
6970

71+
pub(crate) struct LengthCalculatingWriter(pub usize);
72+
impl Writer for LengthCalculatingWriter {
73+
#[inline]
74+
fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
75+
self.0 += buf.len();
76+
Ok(())
77+
}
78+
#[inline]
79+
fn size_hint(&mut self, _size: usize) {}
80+
}
81+
82+
/// Essentially std::io::Take but a bit simpler and with a method to walk the underlying stream
83+
/// forward to ensure we always consume exactly the fixed length specified.
84+
pub(crate) struct FixedLengthReader<R: Read> {
85+
read: R,
86+
bytes_read: u64,
87+
total_fixed_len: u64,
88+
}
89+
impl<R: Read> FixedLengthReader<R> {
90+
pub fn new(read: R, total_fixed_len: u64) -> Self {
91+
Self { read, bytes_read: 0, total_fixed_len }
92+
}
93+
94+
pub fn eat_remaining(&mut self) -> Result<(), DecodeError> {
95+
::std::io::copy(self, &mut ::std::io::sink()).unwrap();
96+
if self.bytes_read != self.total_fixed_len {
97+
Err(DecodeError::ShortRead)
98+
} else {
99+
Ok(())
100+
}
101+
}
102+
}
103+
impl<R: Read> Read for FixedLengthReader<R> {
104+
fn read(&mut self, dest: &mut [u8]) -> Result<usize, ::std::io::Error> {
105+
if self.total_fixed_len == self.bytes_read {
106+
Ok(0)
107+
} else {
108+
let read_len = cmp::min(dest.len() as u64, self.total_fixed_len - self.bytes_read);
109+
match self.read.read(&mut dest[0..(read_len as usize)]) {
110+
Ok(v) => {
111+
self.bytes_read += v as u64;
112+
Ok(v)
113+
},
114+
Err(e) => Err(e),
115+
}
116+
}
117+
}
118+
}
119+
70120
/// A trait that various rust-lightning types implement allowing them to be written out to a Writer
71121
pub trait Writeable {
72122
/// Writes self out to the given Writer

lightning/src/util/ser_macros.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,105 @@
1+
macro_rules! encode_tlv {
2+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
3+
use bitcoin::consensus::Encodable;
4+
use bitcoin::consensus::encode::{Error, VarInt};
5+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
6+
$(
7+
VarInt($type).consensus_encode(WriterWriteAdaptor($stream))
8+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
9+
let mut len_calc = LengthCalculatingWriter(0);
10+
$field.write(&mut len_calc)?;
11+
VarInt(len_calc.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
12+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
13+
$field.write($stream)?;
14+
)*
15+
} }
16+
}
17+
18+
macro_rules! encode_varint_length_prefixed_tlv {
19+
($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
20+
use bitcoin::consensus::Encodable;
21+
use bitcoin::consensus::encode::{Error, VarInt};
22+
use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
23+
let mut len = LengthCalculatingWriter(0);
24+
{
25+
$(
26+
VarInt($type).consensus_encode(WriterWriteAdaptor(&mut len))
27+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
28+
let mut field_len = LengthCalculatingWriter(0);
29+
$field.write(&mut field_len)?;
30+
VarInt(field_len.0 as u64).consensus_encode(WriterWriteAdaptor(&mut len))
31+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
32+
len.0 += field_len.0;
33+
)*
34+
}
35+
36+
VarInt(len.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
37+
.map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
38+
encode_tlv!($stream, {
39+
$(($type, $field)),*
40+
});
41+
} }
42+
}
43+
44+
macro_rules! decode_tlv {
45+
($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
46+
use ln::msgs::DecodeError;
47+
let mut last_seen_type: Option<u64> = None;
48+
'tlv_read: loop {
49+
use bitcoin::consensus::encode;
50+
use util::ser;
51+
use std;
52+
53+
// First decode the type of this TLV:
54+
let typ: encode::VarInt = match encode::Decodable::consensus_decode($stream) {
55+
Err(encode::Error::Io(ref ioe)) if ioe.kind() == std::io::ErrorKind::UnexpectedEof
56+
=> break 'tlv_read,
57+
Err(encode::Error::Io(ioe)) => Err(DecodeError::from(ioe))?,
58+
Err(_) => Err(DecodeError::InvalidValue)?,
59+
Ok(t) => t,
60+
};
61+
62+
// Types must be unique and monotonically increasing:
63+
match last_seen_type {
64+
Some(t) if typ.0 <= t => {
65+
Err(DecodeError::InvalidValue)?
66+
},
67+
_ => {},
68+
}
69+
// As we read types, make sure we hit every required type:
70+
$(if (last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype) && typ.0 > $reqtype {
71+
Err(DecodeError::InvalidValue)?
72+
})*
73+
last_seen_type = Some(typ.0);
74+
75+
// Finally, read the length and value itself:
76+
let length: encode::VarInt = encode::Decodable::consensus_decode($stream)
77+
.map_err(|e| match e {
78+
encode::Error::Io(ioe) => DecodeError::from(ioe),
79+
_ => DecodeError::InvalidValue
80+
})?;
81+
let mut s = ser::FixedLengthReader::new($stream, length.0);
82+
match typ.0 {
83+
$($reqtype => {
84+
$reqfield = ser::Readable::read(&mut s)?;
85+
},)*
86+
$($type => {
87+
$field = Some(ser::Readable::read(&mut s)?);
88+
},)*
89+
x if x % 2 == 0 => {
90+
Err(DecodeError::UnknownRequiredFeature)?
91+
},
92+
_ => {},
93+
}
94+
s.eat_remaining()?;
95+
}
96+
// Make sure we got to each required type after we've read every TLV:
97+
$(if last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype {
98+
Err(DecodeError::InvalidValue)?
99+
})*
100+
} }
101+
}
102+
1103
macro_rules! impl_writeable {
2104
($st:ident, $len: expr, {$($field:ident),*}) => {
3105
impl ::util::ser::Writeable for $st {
@@ -40,3 +142,73 @@ macro_rules! impl_writeable_len_match {
40142
}
41143
}
42144
}
145+
146+
#[cfg(test)]
147+
mod tests {
148+
use std::io::Cursor;
149+
use ln::msgs::DecodeError;
150+
151+
fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
152+
let mut s = Cursor::new(s);
153+
let mut a: u64 = 0;
154+
let mut b: u32 = 0;
155+
let mut c: Option<u32> = None;
156+
decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
157+
Ok((a, b, c))
158+
}
159+
#[test]
160+
fn test_tlv() {
161+
// Value for 3 is longer than we expect, but that's ok...
162+
assert_eq!(tlv_reader(&::hex::decode(
163+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d")
164+
).unwrap()[..]).unwrap(),
165+
(0xdeadbeef1badbeef, 0xdeadbeef, None));
166+
// ...even if there's something afterwards
167+
assert_eq!(tlv_reader(&::hex::decode(
168+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d", "0404ffffffff")
169+
).unwrap()[..]).unwrap(),
170+
(0xdeadbeef1badbeef, 0xdeadbeef, Some(0xffffffff)));
171+
// ...but not if that extra length is missing
172+
if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
173+
concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
174+
).unwrap()[..]) {
175+
} else { panic!(); }
176+
177+
// If they're out of order that's also bad
178+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
179+
concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
180+
).unwrap()[..]) {
181+
} else { panic!(); }
182+
// ...even if its some field we don't understand
183+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
184+
concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
185+
).unwrap()[..]) {
186+
} else { panic!(); }
187+
188+
// It's also bad if they included even fields we don't understand
189+
if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
190+
concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
191+
).unwrap()[..]) {
192+
} else { panic!(); }
193+
// ... or if they're missing fields we need
194+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
195+
concat!("0100", "0208deadbeef1badbeef")
196+
).unwrap()[..]) {
197+
} else { panic!(); }
198+
// ... even if that field is even
199+
if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
200+
concat!("0304deadbeef", "0500")
201+
).unwrap()[..]) {
202+
} else { panic!(); }
203+
204+
// But usually things are pretty much what we expect:
205+
assert_eq!(tlv_reader(&::hex::decode(
206+
concat!("0208deadbeef1badbeef", "03041bad1dea")
207+
).unwrap()[..]).unwrap(),
208+
(0xdeadbeef1badbeef, 0x1bad1dea, None));
209+
assert_eq!(tlv_reader(&::hex::decode(
210+
concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
211+
).unwrap()[..]).unwrap(),
212+
(0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
213+
}
214+
}

0 commit comments

Comments
 (0)