Skip to content

Commit b3825d3

Browse files
Add SigType check in CompilerExtData for deciding key byte-lengths
1 parent d8cc633 commit b3825d3

File tree

1 file changed

+80
-29
lines changed

1 file changed

+80
-29
lines changed

src/policy/compiler.rs

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use std::convert::From;
2222
use std::marker::PhantomData;
2323
use std::{cmp, error, f64, fmt, mem};
2424

25+
use miniscript::context::SigType;
2526
use miniscript::limits::MAX_PUBKEYS_PER_MULTISIG;
2627
use miniscript::types::{self, ErrorKind, ExtData, Property, Type};
2728
use miniscript::ScriptContext;
@@ -136,7 +137,7 @@ impl CompilationKey {
136137
}
137138

138139
#[derive(Copy, Clone, Debug)]
139-
struct CompilerExtData {
140+
struct CompilerExtData<Ctx: ScriptContext> {
140141
/// If this node is the direct child of a disjunction, this field must
141142
/// have the probability of its branch being taken. Otherwise it is ignored.
142143
/// All functions initialize it to `None`.
@@ -148,14 +149,18 @@ struct CompilerExtData {
148149
/// (total length of all witness pushes, plus their own length prefixes)
149150
/// for fragments that can be dissatisfied without failing the script.
150151
dissat_cost: Option<f64>,
152+
/// Signature Type for deciding the number of bytes required for [`CompilerExtData::sat_cost`]
153+
/// and [`CompilerExtData::dissat_cost`] for the given script.
154+
sig_type: PhantomData<Ctx>,
151155
}
152156

153-
impl Property for CompilerExtData {
157+
impl<Ctx: ScriptContext> Property for CompilerExtData<Ctx> {
154158
fn from_true() -> Self {
155159
CompilerExtData {
156160
branch_prob: None,
157161
sat_cost: 0.0,
158162
dissat_cost: None,
163+
sig_type: PhantomData,
159164
}
160165
}
161166

@@ -164,38 +169,64 @@ impl Property for CompilerExtData {
164169
branch_prob: None,
165170
sat_cost: f64::MAX,
166171
dissat_cost: Some(0.0),
172+
sig_type: PhantomData,
167173
}
168174
}
169175

170176
fn from_pk_k() -> Self {
171177
CompilerExtData {
172178
branch_prob: None,
173-
sat_cost: 73.0,
179+
sat_cost: match Ctx::sig_type() {
180+
SigType::Ecdsa => 73.,
181+
SigType::Schnorr => 64.,
182+
},
174183
dissat_cost: Some(1.0),
184+
sig_type: PhantomData,
175185
}
176186
}
177187

178188
fn from_pk_h() -> Self {
179189
CompilerExtData {
180190
branch_prob: None,
181-
sat_cost: 73.0 + 34.0,
182-
dissat_cost: Some(1.0 + 34.0),
191+
sat_cost: match Ctx::sig_type() {
192+
SigType::Ecdsa => 73.0 + 34.0,
193+
SigType::Schnorr => 64. + 32.,
194+
},
195+
dissat_cost: Some(
196+
1.0 + match Ctx::sig_type() {
197+
SigType::Ecdsa => 34.0,
198+
SigType::Schnorr => 32.,
199+
},
200+
),
201+
sig_type: PhantomData,
183202
}
184203
}
185204

186205
fn from_multi(k: usize, _n: usize) -> Self {
187206
CompilerExtData {
188207
branch_prob: None,
189-
sat_cost: 1.0 + 73.0 * k as f64,
208+
sat_cost: 1.0
209+
+ match Ctx::sig_type() {
210+
SigType::Ecdsa => 73.0 * k as f64,
211+
SigType::Schnorr => 64. * k as f64,
212+
},
190213
dissat_cost: Some(1.0 * (k + 1) as f64),
214+
sig_type: PhantomData,
191215
}
192216
}
193217

194218
fn from_hash() -> Self {
195219
CompilerExtData {
196220
branch_prob: None,
197-
sat_cost: 33.0,
198-
dissat_cost: Some(33.0),
221+
sat_cost: match Ctx::sig_type() {
222+
SigType::Ecdsa => 33.0,
223+
SigType::Schnorr => 32.,
224+
},
225+
dissat_cost: Some(match Ctx::sig_type() {
226+
SigType::Ecdsa => 33.0,
227+
SigType::Schnorr => 32.,
228+
}),
229+
sig_type: PhantomData,
199230
}
200231
}
201232

@@ -204,6 +235,7 @@ impl Property for CompilerExtData {
204235
branch_prob: None,
205236
sat_cost: 0.0,
206237
dissat_cost: None,
238+
sig_type: PhantomData,
207239
}
208240
}
209241

@@ -212,6 +244,7 @@ impl Property for CompilerExtData {
212244
branch_prob: None,
213245
sat_cost: self.sat_cost,
214246
dissat_cost: self.dissat_cost,
247+
sig_type: PhantomData,
215248
})
216249
}
217250

@@ -220,6 +253,7 @@ impl Property for CompilerExtData {
220253
branch_prob: None,
221254
sat_cost: self.sat_cost,
222255
dissat_cost: self.dissat_cost,
256+
sig_type: PhantomData,
223257
})
224258
}
225259

@@ -228,6 +262,7 @@ impl Property for CompilerExtData {
228262
branch_prob: None,
229263
sat_cost: self.sat_cost,
230264
dissat_cost: self.dissat_cost,
265+
sig_type: PhantomData,
231266
})
232267
}
233268

@@ -236,6 +271,7 @@ impl Property for CompilerExtData {
236271
branch_prob: None,
237272
sat_cost: 2.0 + self.sat_cost,
238273
dissat_cost: Some(1.0),
274+
sig_type: PhantomData,
239275
})
240276
}
241277

@@ -244,6 +280,7 @@ impl Property for CompilerExtData {
244280
branch_prob: None,
245281
sat_cost: self.sat_cost,
246282
dissat_cost: None,
283+
sig_type: PhantomData,
247284
})
248285
}
249286

@@ -252,6 +289,7 @@ impl Property for CompilerExtData {
252289
branch_prob: None,
253290
sat_cost: self.sat_cost,
254291
dissat_cost: Some(1.0),
292+
sig_type: PhantomData,
255293
})
256294
}
257295

@@ -260,6 +298,7 @@ impl Property for CompilerExtData {
260298
branch_prob: None,
261299
sat_cost: self.sat_cost,
262300
dissat_cost: self.dissat_cost,
301+
sig_type: PhantomData,
263302
})
264303
}
265304

@@ -268,6 +307,7 @@ impl Property for CompilerExtData {
268307
branch_prob: None,
269308
sat_cost: self.sat_cost,
270309
dissat_cost: None,
310+
sig_type: PhantomData,
271311
})
272312
}
273313

@@ -281,6 +321,7 @@ impl Property for CompilerExtData {
281321
branch_prob: None,
282322
sat_cost: 2.0 + self.sat_cost,
283323
dissat_cost: Some(1.0),
324+
sig_type: PhantomData,
284325
})
285326
}
286327

@@ -289,6 +330,7 @@ impl Property for CompilerExtData {
289330
branch_prob: None,
290331
sat_cost: 1.0 + self.sat_cost,
291332
dissat_cost: Some(2.0),
333+
sig_type: PhantomData,
292334
})
293335
}
294336

@@ -300,6 +342,7 @@ impl Property for CompilerExtData {
300342
(Some(l), Some(r)) => Some(l + r),
301343
_ => None,
302344
},
345+
sig_type: PhantomData,
303346
})
304347
}
305348

@@ -308,6 +351,16 @@ impl Property for CompilerExtData {
308351
branch_prob: None,
309352
sat_cost: left.sat_cost + right.sat_cost,
310353
dissat_cost: None,
354+
sig_type: PhantomData,
355+
})
356+
}
357+
358+
fn and_n(a: Self, b: Self) -> Result<Self, types::ErrorKind> {
359+
Ok(CompilerExtData {
360+
branch_prob: None,
361+
sat_cost: a.sat_cost + b.sat_cost,
362+
dissat_cost: a.dissat_cost,
363+
sig_type: PhantomData,
311364
})
312365
}
313366

@@ -323,6 +376,7 @@ impl Property for CompilerExtData {
323376
sat_cost: lprob * (l.sat_cost + r.dissat_cost.unwrap())
324377
+ rprob * (r.sat_cost + l.dissat_cost.unwrap()),
325378
dissat_cost: Some(l.dissat_cost.unwrap() + r.dissat_cost.unwrap()),
379+
sig_type: PhantomData,
326380
})
327381
}
328382

@@ -337,6 +391,7 @@ impl Property for CompilerExtData {
337391
branch_prob: None,
338392
sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()),
339393
dissat_cost: r.dissat_cost.map(|rd| l.dissat_cost.unwrap() + rd),
394+
sig_type: PhantomData,
340395
})
341396
}
342397

@@ -351,6 +406,7 @@ impl Property for CompilerExtData {
351406
branch_prob: None,
352407
sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()),
353408
dissat_cost: None,
409+
sig_type: PhantomData,
354410
})
355411
}
356412

@@ -377,6 +433,7 @@ impl Property for CompilerExtData {
377433
} else {
378434
None
379435
},
436+
sig_type: PhantomData,
380437
})
381438
}
382439

@@ -400,14 +457,7 @@ impl Property for CompilerExtData {
400457
} else {
401458
None
402459
},
403-
})
404-
}
405-
406-
fn and_n(a: Self, b: Self) -> Result<Self, types::ErrorKind> {
407-
Ok(CompilerExtData {
408-
branch_prob: None,
409-
sat_cost: a.sat_cost + b.sat_cost,
410-
dissat_cost: a.dissat_cost,
460+
sig_type: PhantomData,
411461
})
412462
}
413463

@@ -427,6 +477,7 @@ impl Property for CompilerExtData {
427477
branch_prob: None,
428478
sat_cost: sat_cost * k_over_n + dissat_cost * (1.0 - k_over_n),
429479
dissat_cost: Some(dissat_cost),
480+
sig_type: PhantomData,
430481
})
431482
}
432483
}
@@ -437,7 +488,7 @@ struct AstElemExt<Pk: MiniscriptKey, Ctx: ScriptContext> {
437488
/// The actual Miniscript fragment with type information
438489
ms: Arc<Miniscript<Pk, Ctx>>,
439490
/// Its "type" in terms of compiler data
440-
comp_ext_data: CompilerExtData,
491+
comp_ext_data: CompilerExtData<Ctx>,
441492
}
442493

443494
impl<Pk: MiniscriptKey, Ctx: ScriptContext> AstElemExt<Pk, Ctx> {
@@ -470,8 +521,8 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> AstElemExt<Pk, Ctx> {
470521
r: &AstElemExt<Pk, Ctx>,
471522
) -> Result<AstElemExt<Pk, Ctx>, types::Error<Pk, Ctx>> {
472523
let lookup_ext = |n| match n {
473-
0 => Some(l.comp_ext_data),
474-
1 => Some(r.comp_ext_data),
524+
0 => Some(l.comp_ext_data.clone()),
525+
1 => Some(r.comp_ext_data.clone()),
475526
_ => unreachable!(),
476527
};
477528
//Types and ExtData are already cached and stored in children. So, we can
@@ -497,9 +548,9 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> AstElemExt<Pk, Ctx> {
497548
c: &AstElemExt<Pk, Ctx>,
498549
) -> Result<AstElemExt<Pk, Ctx>, types::Error<Pk, Ctx>> {
499550
let lookup_ext = |n| match n {
500-
0 => Some(a.comp_ext_data),
501-
1 => Some(b.comp_ext_data),
502-
2 => Some(c.comp_ext_data),
551+
0 => Some(a.comp_ext_data.clone()),
552+
1 => Some(b.comp_ext_data.clone()),
553+
2 => Some(c.comp_ext_data.clone()),
503554
_ => unreachable!(),
504555
};
505556
//Types and ExtData are already cached and stored in children. So, we can
@@ -525,7 +576,7 @@ struct Cast<Pk: MiniscriptKey, Ctx: ScriptContext> {
525576
node: fn(Arc<Miniscript<Pk, Ctx>>) -> Terminal<Pk, Ctx>,
526577
ast_type: fn(types::Type) -> Result<types::Type, ErrorKind>,
527578
ext_data: fn(types::ExtData) -> Result<types::ExtData, ErrorKind>,
528-
comp_ext_data: fn(CompilerExtData) -> Result<CompilerExtData, types::ErrorKind>,
579+
comp_ext_data: fn(CompilerExtData<Ctx>) -> Result<CompilerExtData<Ctx>, types::ErrorKind>,
529580
}
530581

531582
impl<Pk: MiniscriptKey, Ctx: ScriptContext> Cast<Pk, Ctx> {
@@ -537,7 +588,7 @@ impl<Pk: MiniscriptKey, Ctx: ScriptContext> Cast<Pk, Ctx> {
537588
node: (self.node)(Arc::clone(&ast.ms)),
538589
phantom: PhantomData,
539590
}),
540-
comp_ext_data: (self.comp_ext_data)(ast.comp_ext_data)?,
591+
comp_ext_data: (self.comp_ext_data)(ast.comp_ext_data.clone())?,
541592
})
542593
}
543594
}
@@ -955,19 +1006,19 @@ where
9551006
let bw = best(types::Base::W, policy_cache, ast, sp, dp)?;
9561007

9571008
let diff = be.cost_1d(sp, dp) - bw.cost_1d(sp, dp);
958-
best_es.push((be.comp_ext_data, be));
959-
best_ws.push((bw.comp_ext_data, bw));
1009+
best_es.push((be.comp_ext_data.clone(), be));
1010+
best_ws.push((bw.comp_ext_data.clone(), bw));
9601011

9611012
if diff < min_value.1 {
9621013
min_value.0 = i;
9631014
min_value.1 = diff;
9641015
}
9651016
}
966-
sub_ext_data.push(best_es[min_value.0].0);
1017+
sub_ext_data.push(best_es[min_value.0].0.clone());
9671018
sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms));
9681019
for (i, _ast) in subs.iter().enumerate() {
9691020
if i != min_value.0 {
970-
sub_ext_data.push(best_ws[i].0);
1021+
sub_ext_data.push(best_ws[i].0.clone());
9711022
sub_ast.push(Arc::clone(&best_ws[i].1.ms));
9721023
}
9731024
}
@@ -978,7 +1029,7 @@ where
9781029
Miniscript::from_ast(ast)
9791030
.expect("threshold subs, which we just compiled, typeck"),
9801031
),
981-
comp_ext_data: CompilerExtData::threshold(k, n, |i| Ok(sub_ext_data[i]))
1032+
comp_ext_data: CompilerExtData::threshold(k, n, |i| Ok(sub_ext_data[i].clone()))
9821033
.expect("threshold subs, which we just compiled, typeck"),
9831034
};
9841035
insert_wrap!(ast_ext);

0 commit comments

Comments
 (0)