|
18 | 18 | //! Refer to [`crate::encode`] for information on the encoding.
|
19 | 19 |
|
20 | 20 | use crate::bititer::BitIter;
|
| 21 | +use crate::core::commit::{CommitNodeInner, RefWrapper}; |
| 22 | +use crate::core::iter::DagIterable; |
| 23 | +use crate::core::types::{Type, TypeInner}; |
| 24 | +use crate::core::{CommitNode, Value}; |
| 25 | +use crate::jet::Application; |
| 26 | +use crate::merkle::cmr::Cmr; |
21 | 27 | use crate::Error;
|
| 28 | +use std::collections::HashMap; |
| 29 | +use std::rc::Rc; |
22 | 30 |
|
23 |
| -/* |
24 |
| -/// Decode an untyped Simplicity program from bits. |
| 31 | +/// Decode a Simplicity program from bits, without witness data. |
25 | 32 | pub fn decode_program_no_witness<I: Iterator<Item = u8>, App: Application>(
|
26 |
| - iter: &mut BitIter<I>, |
27 |
| -) -> Result<UntypedProgram<(), App>, Error> { |
28 |
| - let prog_len = decode_natural(iter, None)?; |
| 33 | + bits: &mut BitIter<I>, |
| 34 | +) -> Result<Rc<CommitNode<(), App>>, Error> { |
| 35 | + let len = decode_natural(bits, None)?; |
29 | 36 |
|
30 |
| - // FIXME: check maximum length of DAG that is allowed by consensus |
31 |
| - if prog_len > 1_000_000 { |
32 |
| - return Err(Error::TooManyNodes(prog_len)); |
| 37 | + if len == 0 { |
| 38 | + return Err(Error::EmptyProgram); |
33 | 39 | }
|
34 |
| -
|
35 |
| - let mut program = Vec::with_capacity(prog_len); |
36 |
| - for _ in 0..prog_len { |
37 |
| - decode_node(&mut program, iter)?; |
| 40 | + // FIXME: check maximum length of DAG that is allowed by consensus |
| 41 | + if len > 1_000_000 { |
| 42 | + return Err(Error::TooManyNodes(len)); |
38 | 43 | }
|
39 | 44 |
|
40 |
| - let program = UntypedProgram(program); |
| 45 | + let mut index_to_node = HashMap::new(); |
41 | 46 |
|
42 |
| - if program.has_canonical_order() { |
43 |
| - Ok(program) |
44 |
| - } else { |
45 |
| - Err(Error::ParseError("Program is not in canonical order!")) |
| 47 | + for index in 0..len { |
| 48 | + decode_node(bits, index, &mut index_to_node)?; |
46 | 49 | }
|
47 |
| -} |
48 | 50 |
|
49 |
| -/// Decode witness data from bits. |
50 |
| -pub fn decode_witness<Wit, App: Application, I: Iterator<Item = u8>>( |
51 |
| - program: &TypedProgram<Wit, App>, |
52 |
| - iter: &mut BitIter<I>, |
53 |
| -) -> Result<Vec<Value>, Error> { |
54 |
| - let bit_len = match iter.next() { |
55 |
| - Some(false) => 0, |
56 |
| - Some(true) => decode_natural(iter, None)?, |
57 |
| - None => return Err(Error::EndOfStream), |
58 |
| - }; |
59 |
| - let mut witness = Vec::new(); |
60 |
| - let n_start = iter.n_total_read(); |
| 51 | + let root_index = len - 1; |
| 52 | + let root = index_to_node.get(&root_index).unwrap().clone(); |
| 53 | + let connected_len = RefWrapper(&root).iter_post_order().count(); |
61 | 54 |
|
62 |
| - for node in &program.0 { |
63 |
| - if let Term::Witness(_old_witness) = &node.term { |
64 |
| - witness.push(decode_value(&node.target_ty, iter)?); |
65 |
| - } |
| 55 | + if connected_len != len { |
| 56 | + return Err(Error::InconsistentProgramLength); |
66 | 57 | }
|
67 | 58 |
|
68 |
| - if iter.n_total_read() - n_start != bit_len { |
69 |
| - Err(Error::ParseError( |
70 |
| - "Witness bit string has different length than defined in its preamble", |
71 |
| - )) |
72 |
| - } else { |
73 |
| - Ok(witness) |
74 |
| - } |
75 |
| -} |
76 |
| -
|
77 |
| -/// Decode a value from bits, based on the given type. |
78 |
| -pub fn decode_value<I: Iterator<Item = bool>>(ty: &Type, iter: &mut I) -> Result<Value, Error> { |
79 |
| - let value = match ty.ty { |
80 |
| - TypeInner::Unit => Value::Unit, |
81 |
| - TypeInner::Sum(ref l, ref r) => match iter.next() { |
82 |
| - Some(false) => Value::SumL(Box::new(decode_value(l, iter)?)), |
83 |
| - Some(true) => Value::SumR(Box::new(decode_value(r, iter)?)), |
84 |
| - None => return Err(Error::EndOfStream), |
85 |
| - }, |
86 |
| - TypeInner::Product(ref l, ref r) => Value::Prod( |
87 |
| - Box::new(decode_value(l, iter)?), |
88 |
| - Box::new(decode_value(r, iter)?), |
89 |
| - ), |
90 |
| - }; |
91 |
| -
|
92 |
| - Ok(value) |
| 59 | + Ok(root) |
93 | 60 | }
|
94 | 61 |
|
95 |
| -/// Decode an untyped Simplicity term from bits and add it to the given program. |
| 62 | +/// Decode a single Simplicity node from bits and |
| 63 | +/// insert it into a hash map at its index for future reference by ancestor nodes. |
96 | 64 | fn decode_node<I: Iterator<Item = u8>, App: Application>(
|
97 |
| - program: &mut Vec<Term<(), App>>, |
98 |
| - iter: &mut BitIter<I>, |
| 65 | + bits: &mut BitIter<I>, |
| 66 | + index: usize, |
| 67 | + index_to_node: &mut HashMap<usize, Rc<CommitNode<(), App>>>, |
99 | 68 | ) -> Result<(), Error> {
|
100 |
| - match iter.next() { |
| 69 | + match bits.next() { |
101 | 70 | None => return Err(Error::EndOfStream),
|
102 |
| - Some(true) => return decode_jet(program, iter), |
| 71 | + Some(true) => { |
| 72 | + let node = CommitNode::jet(App::decode_jet(bits)?); |
| 73 | + debug_assert!(!index_to_node.contains_key(&index)); |
| 74 | + index_to_node.insert(index, node); |
| 75 | + return Ok(()); |
| 76 | + } |
103 | 77 | Some(false) => {}
|
104 | 78 | };
|
105 | 79 |
|
106 |
| - let code = match iter.read_bits_be(2) { |
| 80 | + let code = match bits.read_bits_be(2) { |
107 | 81 | Some(n) => n,
|
108 | 82 | None => return Err(Error::EndOfStream),
|
109 | 83 | };
|
110 |
| - let subcode = match iter.read_bits_be(if code < 3 { 2 } else { 1 }) { |
| 84 | + let subcode = match bits.read_bits_be(if code < 3 { 2 } else { 1 }) { |
111 | 85 | Some(n) => n,
|
112 | 86 | None => return Err(Error::EndOfStream),
|
113 | 87 | };
|
114 | 88 | let node = if code <= 1 {
|
115 |
| - let idx = program.len(); |
116 |
| - let i = decode_natural(iter, Some(idx))?; |
| 89 | + let i_abs = index - decode_natural(bits, Some(index))?; |
| 90 | + let left = get_child_from_index(i_abs, index_to_node); |
117 | 91 |
|
118 | 92 | if code == 0 {
|
119 |
| - let j = decode_natural(iter, Some(idx))?; |
| 93 | + let j_abs = index - decode_natural(bits, Some(index))?; |
| 94 | + let right = get_child_from_index(j_abs, index_to_node); |
120 | 95 |
|
121 | 96 | match subcode {
|
122 |
| - 0 => Term::Comp(i, j), |
| 97 | + 0 => CommitNode::comp(left, right), |
123 | 98 | 1 => {
|
124 |
| - let mut node = Term::Case(i, j); |
125 |
| - let mut left_hidden = false; |
126 |
| -
|
127 |
| - if let Term::Hidden(..) = program[idx - i] { |
128 |
| - node = Term::AssertR(i, j); |
129 |
| - left_hidden = true; |
130 |
| - } |
131 |
| - if let Term::Hidden(..) = program[idx - j] { |
132 |
| - if left_hidden { |
| 99 | + if let CommitNodeInner::Hidden(..) = left.inner { |
| 100 | + if let CommitNodeInner::Hidden(..) = right.inner { |
133 | 101 | return Err(Error::CaseMultipleHiddenChildren);
|
134 | 102 | }
|
135 |
| -
|
136 |
| - node = Term::AssertL(i, j); |
137 | 103 | }
|
138 | 104 |
|
139 |
| - node |
| 105 | + if let CommitNodeInner::Hidden(..) = right.inner { |
| 106 | + CommitNode::assertl(left, right) |
| 107 | + } else if let CommitNodeInner::Hidden(..) = left.inner { |
| 108 | + CommitNode::assertr(left, right) |
| 109 | + } else { |
| 110 | + CommitNode::case(left, right) |
| 111 | + } |
140 | 112 | }
|
141 |
| - 2 => Term::Pair(i, j), |
142 |
| - 3 => Term::Disconnect(i, j), |
143 |
| - _ => unreachable!(), |
| 113 | + 2 => CommitNode::pair(left, right), |
| 114 | + 3 => CommitNode::disconnect(left, right), |
| 115 | + // TODO: convert into crate::Error::ParseError |
| 116 | + _ => unimplemented!(), |
144 | 117 | }
|
145 | 118 | } else {
|
146 | 119 | match subcode {
|
147 |
| - 0 => Term::InjL(i), |
148 |
| - 1 => Term::InjR(i), |
149 |
| - 2 => Term::Take(i), |
150 |
| - 3 => Term::Drop(i), |
151 |
| - _ => unreachable!(), |
| 120 | + 0 => CommitNode::injl(left), |
| 121 | + 1 => CommitNode::injr(left), |
| 122 | + 2 => CommitNode::take(left), |
| 123 | + 3 => CommitNode::drop(left), |
| 124 | + _ => unimplemented!(), |
152 | 125 | }
|
153 | 126 | }
|
154 | 127 | } else if code == 2 {
|
155 | 128 | match subcode {
|
156 |
| - 0 => Term::Iden, |
157 |
| - 1 => Term::Unit, |
158 |
| - 2 => Term::Fail(Cmr::from(decode_hash(iter)?), Cmr::from(decode_hash(iter)?)), |
| 129 | + 0 => CommitNode::iden(), |
| 130 | + 1 => CommitNode::unit(), |
| 131 | + 2 => CommitNode::fail(Cmr::from(decode_hash(bits)?), Cmr::from(decode_hash(bits)?)), |
159 | 132 | 3 => return Err(Error::ParseError("01011 (stop code)")),
|
160 |
| - _ => unreachable!(), |
| 133 | + _ => unimplemented!(), |
161 | 134 | }
|
162 | 135 | } else if code == 3 {
|
163 | 136 | match subcode {
|
164 |
| - 0 => Term::Hidden(Cmr::from(decode_hash(iter)?)), |
165 |
| - 1 => Term::Witness(()), |
166 |
| - _ => unreachable!(), |
| 137 | + 0 => CommitNode::hidden(Cmr::from(decode_hash(bits)?)), |
| 138 | + 1 => CommitNode::witness(()), |
| 139 | + _ => unimplemented!(), |
167 | 140 | }
|
168 | 141 | } else {
|
169 |
| - unreachable!() |
| 142 | + unimplemented!() |
170 | 143 | };
|
171 | 144 |
|
172 |
| - program.push(node); |
| 145 | + debug_assert!(!index_to_node.contains_key(&index)); |
| 146 | + index_to_node.insert(index, node); |
173 | 147 | Ok(())
|
174 | 148 | }
|
175 | 149 |
|
176 |
| -/// Decode a Simplicity jet from bits. |
177 |
| -fn decode_jet<I: Iterator<Item = u8>, App: Application>( |
178 |
| - program: &mut Vec<Term<(), App>>, |
179 |
| - iter: &mut BitIter<I>, |
180 |
| -) -> Result<(), Error> { |
181 |
| - let node = Term::Jet(App::decode_jet(iter)?); |
182 |
| - program.push(node); |
183 |
| - Ok(()) |
| 150 | +/// Return the child node at the given index from a hash map. |
| 151 | +fn get_child_from_index<App: Application>( |
| 152 | + index: usize, |
| 153 | + index_to_node: &HashMap<usize, Rc<CommitNode<(), App>>>, |
| 154 | +) -> Rc<CommitNode<(), App>> { |
| 155 | + index_to_node |
| 156 | + .get(&index) |
| 157 | + .expect("Children come before parent in post order") |
| 158 | + .clone() |
| 159 | + // TODO: Return fresh witness once sharing of unpopulated witness nodes is implemented |
| 160 | +} |
| 161 | + |
| 162 | +/// Iterator over witness values that asks for the value type on each iteration. |
| 163 | +pub trait WitnessIterator { |
| 164 | + /// Return the next witness value of the given type. |
| 165 | + fn next(&mut self, ty: &Type) -> Result<Value, Error>; |
| 166 | + |
| 167 | + /// Consume the iterator and check the total witness length. |
| 168 | + fn finish(self) -> Result<(), Error>; |
| 169 | +} |
| 170 | + |
| 171 | +impl<I: Iterator<Item = Value>> WitnessIterator for I { |
| 172 | + fn next(&mut self, _ty: &Type) -> Result<Value, Error> { |
| 173 | + Iterator::next(self).ok_or(Error::EndOfStream) |
| 174 | + } |
| 175 | + |
| 176 | + fn finish(self) -> Result<(), Error> { |
| 177 | + Ok(()) |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +/// Implementation of [`WitnessIterator`] for an underlying [`BitIter`]. |
| 182 | +#[derive(Debug)] |
| 183 | +pub struct WitnessDecoder<'a, I: Iterator<Item = u8>> { |
| 184 | + bits: &'a mut BitIter<I>, |
| 185 | + max_n: usize, |
| 186 | +} |
| 187 | + |
| 188 | +impl<'a, I: Iterator<Item = u8>> WitnessDecoder<'a, I> { |
| 189 | + /// Create a new witness decoder for the given bit iterator. |
| 190 | + /// To work, this method must be used **after** [`decode_program_no_witness()`]! |
| 191 | + pub fn new(bits: &'a mut BitIter<I>) -> Result<Self, Error> { |
| 192 | + let bit_len = match bits.next() { |
| 193 | + Some(false) => 0, |
| 194 | + Some(true) => decode_natural(bits, None)?, |
| 195 | + None => return Err(Error::EndOfStream), |
| 196 | + }; |
| 197 | + let n_start = bits.n_total_read(); |
| 198 | + |
| 199 | + Ok(Self { |
| 200 | + bits, |
| 201 | + max_n: n_start + bit_len, |
| 202 | + }) |
| 203 | + } |
| 204 | +} |
| 205 | + |
| 206 | +impl<'a, I: Iterator<Item = u8>> WitnessIterator for WitnessDecoder<'a, I> { |
| 207 | + fn next(&mut self, ty: &Type) -> Result<Value, Error> { |
| 208 | + decode_value(ty, self.bits) |
| 209 | + } |
| 210 | + |
| 211 | + fn finish(self) -> Result<(), Error> { |
| 212 | + if self.bits.n_total_read() != self.max_n { |
| 213 | + Err(Error::InconsistentWitnessLength) |
| 214 | + } else { |
| 215 | + Ok(()) |
| 216 | + } |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +/// Decode a value from bits, based on the given type. |
| 221 | +pub fn decode_value<I: Iterator<Item = bool>>(ty: &Type, iter: &mut I) -> Result<Value, Error> { |
| 222 | + let value = match ty.ty { |
| 223 | + TypeInner::Unit => Value::Unit, |
| 224 | + TypeInner::Sum(ref l, ref r) => match iter.next() { |
| 225 | + Some(false) => Value::SumL(Box::new(decode_value(l, iter)?)), |
| 226 | + Some(true) => Value::SumR(Box::new(decode_value(r, iter)?)), |
| 227 | + None => return Err(Error::EndOfStream), |
| 228 | + }, |
| 229 | + TypeInner::Product(ref l, ref r) => Value::Prod( |
| 230 | + Box::new(decode_value(l, iter)?), |
| 231 | + Box::new(decode_value(r, iter)?), |
| 232 | + ), |
| 233 | + }; |
| 234 | + |
| 235 | + Ok(value) |
184 | 236 | }
|
185 |
| -*/ |
186 | 237 |
|
187 | 238 | /// Decode a 256-bit hash from bits.
|
188 | 239 | fn decode_hash<I: Iterator<Item = u8>>(iter: &mut BitIter<I>) -> Result<[u8; 32], Error> {
|
|
0 commit comments