Skip to content

Refactor threshold method #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/interpreter/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,6 @@ pub(super) fn from_txdata<'txin>(
// Creating new contexts is cheap
let secp = bitcoin::secp256k1::Secp256k1::verification_only();
let tap_script = tap_script.encode();
// Should not really need to call dangerous assumed tweaked here.
// Should be fixed after RC
if ctrl_blk.verify_taproot_commitment(&secp, output_key, &tap_script) {
Ok((
Inner::Script(ms, ScriptType::Tr),
Expand Down
154 changes: 97 additions & 57 deletions src/miniscript/types/extra_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,17 +762,13 @@ impl Property for ExtData {
let mut ops_count = 0;
let mut ops_count_sat_vec = Vec::with_capacity(n);
let mut ops_count_nsat_sum = 0;
let mut op_count_sat = Some(0);
let mut timelocks = Vec::with_capacity(n);
let mut stack_elem_count_sat_vec = Vec::with_capacity(n);
let mut stack_elem_count_sat = Some(0);
let mut stack_elem_count_dissat = Some(0);
let mut max_sat_size_vec = Vec::with_capacity(n);
let mut max_sat_size = Some((0, 0));
let mut max_dissat_size = Some((0, 0));
// the max element count is same as max sat element count when satisfying one element + 1
let mut exec_stack_elem_count_sat_vec = Vec::with_capacity(n);
let mut exec_stack_elem_count_sat = Some(0);
let mut exec_stack_elem_count_dissat = Some(0);

for i in 0..n {
Expand Down Expand Up @@ -809,61 +805,60 @@ impl Property for ExtData {
);
}

// We sort by [satisfaction cost - dissatisfaction cost] to make a worst-case (the most
// costy satisfaction are satisfied, the most costy dissatisfactions are dissatisfied)
// sum of the cost by iterating through the sorted vector *backward*.
stack_elem_count_sat_vec.sort_by(|a, b| {
a.0.map(|x| a.1.map(|y| x as isize - y as isize))
.cmp(&b.0.map(|x| b.1.map(|y| x as isize - y as isize)))
});
for (i, &(x, y)) in stack_elem_count_sat_vec.iter().rev().enumerate() {
stack_elem_count_sat = if i <= k {
x.and_then(|x| stack_elem_count_sat.map(|count| count + x))
} else {
y.and_then(|y| stack_elem_count_sat.map(|count| count + y))
};
}
stack_elem_count_sat_vec.sort_by(sat_minus_option_dissat);
let stack_elem_count_sat =
stack_elem_count_sat_vec
.iter()
.rev()
.enumerate()
.fold(Some(0), |acc, (i, &(x, y))| {
if i <= k {
opt_add(acc, x)
} else {
opt_add(acc, y)
}
});

// Same logic as above
exec_stack_elem_count_sat_vec.sort_by(|a, b| {
a.0.map(|x| a.1.map(|y| x as isize - y as isize))
.cmp(&b.0.map(|x| b.1.map(|y| x as isize - y as isize)))
});
for (i, &(x, y)) in exec_stack_elem_count_sat_vec.iter().rev().enumerate() {
exec_stack_elem_count_sat = if i <= k {
opt_max(exec_stack_elem_count_sat, x)
} else {
opt_max(exec_stack_elem_count_sat, y)
};
}
exec_stack_elem_count_sat_vec.sort_by(sat_minus_option_dissat);
let exec_stack_elem_count_sat = exec_stack_elem_count_sat_vec
.iter()
.rev()
.enumerate()
.fold(Some(0), |acc, (i, &(x, y))| {
if i <= k {
opt_max(acc, x)
} else {
opt_max(acc, y)
}
});

// Same for the size cost. A bit more intricated as we need to account for both the witness
// and scriptSig cost, so we end up with a tuple of Options of tuples. We use the witness
// cost (first element of the mentioned tuple) here.
// FIXME: Maybe make the ExtData struct aware of Ctx and add a one_cost() method here ?
max_sat_size_vec.sort_by(|a, b| {
a.0.map(|x| a.1.map(|y| x.0 as isize - y.0 as isize))
.cmp(&b.0.map(|x| b.1.map(|y| x.0 as isize - y.0 as isize)))
});
for (i, &(x, y)) in max_sat_size_vec.iter().enumerate() {
max_sat_size = if i <= k {
x.and_then(|x| max_sat_size.map(|(w, s)| (w + x.0, s + x.1)))
} else {
y.and_then(|y| max_sat_size.map(|(w, s)| (w + y.0, s + y.1)))
};
}
max_sat_size_vec.sort_by(sat_minus_dissat_witness);
let max_sat_size =
max_sat_size_vec
.iter()
.enumerate()
.fold(Some((0, 0)), |acc, (i, &(x, y))| {
if i <= k {
opt_tuple_add(acc, x)
} else {
opt_tuple_add(acc, y)
}
});

ops_count_sat_vec.sort_by(sat_minus_dissat);
let op_count_sat =
ops_count_sat_vec
.iter()
.enumerate()
.fold(Some(0), |acc, (i, &(x, y))| {
if i <= k {
opt_add(acc, x)
} else {
opt_add(acc, Some(y))
}
});

ops_count_sat_vec.sort_by(|a, b| {
a.0.map(|x| x as isize - a.1 as isize)
.cmp(&b.0.map(|x| x as isize - b.1 as isize))
});
for (i, &(x, y)) in ops_count_sat_vec.iter().enumerate() {
op_count_sat = if i <= k {
opt_add(op_count_sat, x)
} else {
opt_add(op_count_sat, Some(y))
};
}
Ok(ExtData {
pk_cost: pk_cost + n - 1, //all pk cost + (n-1)*ADD
has_free_verify: true,
Expand Down Expand Up @@ -1021,7 +1016,47 @@ impl Property for ExtData {
}
}

// Returns Some(max(x,y)) is both x and y are Some. Otherwise, return none
// Function to pass to sort_by. Sort by (satisfaction cost - dissatisfaction cost).
//
// We sort by (satisfaction cost - dissatisfaction cost) to make a worst-case (the most
// costy satisfactions are satisfied, the most costy dissatisfactions are dissatisfied).
//
// Args are of form: (<count_sat>, <count_dissat>)
fn sat_minus_dissat<'r, 's>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In cb4c037:

I think all these functions should be defined at the top of the function that they're used in.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad you brought this up, I've wanted to discuss it with you for ages now :) When I started writing Rust I had come from C and used to do as you suggest. I was then introduced to the idea of laying out a source file with the most important things at the top, this implies putting functions below where they are called. Another data point; I also read the idea somewhere of writing functions at a single level of abstraction the writing next, and below that, all the functions at the next layer of abstraction and so on.

Related; we have a bunch of places where the error struct and implementation is at the top of the file, I'd prefer to see that code buried way down lower so I don't see it very often.

Is there a reason you favour putting functions above where they were first called (apart from habit from writing/reading C code)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like these functions' visibility to match the only scope where they're used, since they are essentially helpers for the code in that scope.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't about being C-like -- you can't even do what I'm suggesting in C.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no strong preferences here.
If the name is descriptive and is being used multiple times in different methods(I think it is the case for some of the functions here). It makes sense to have it at the bottom.

On the other hand, if it is just a helper function, it makes sense to have it a function near the body or even a sub-function inside the main function. The readers can cleanly separate these things.

a: &'r (Option<usize>, usize),
b: &'s (Option<usize>, usize),
) -> std::cmp::Ordering {
a.0.map(|x| x as isize - a.1 as isize)
.cmp(&b.0.map(|x| x as isize - b.1 as isize))
}

// Function to pass to sort_by. Sort by (satisfaction cost - dissatisfaction cost).
//
// We sort by (satisfaction cost - dissatisfaction cost) to make a worst-case (the most
// costy satisfactions are satisfied, the most costy dissatisfactions are dissatisfied).
//
// Args are of form: (<count_sat>, <count_dissat>)
fn sat_minus_option_dissat<'r, 's>(
a: &'r (Option<usize>, Option<usize>),
b: &'s (Option<usize>, Option<usize>),
) -> std::cmp::Ordering {
a.0.map(|x| a.1.map(|y| x as isize - y as isize))
.cmp(&b.0.map(|x| b.1.map(|y| x as isize - y as isize)))
}

// Function to pass to sort_by. Sort by (satisfaction cost - dissatisfaction cost) of cost of witness.
//
// Args are of form: (<max_sat_size>, <count_dissat_size>)
// max_[dis]sat_size of form: (<cost_of_witness>, <cost_of_sciptsig>)
fn sat_minus_dissat_witness<'r, 's>(
a: &'r (Option<(usize, usize)>, Option<(usize, usize)>),
b: &'s (Option<(usize, usize)>, Option<(usize, usize)>),
) -> std::cmp::Ordering {
a.0.map(|x| a.1.map(|y| x.0 as isize - y.0 as isize))
.cmp(&b.0.map(|x| b.1.map(|y| x.0 as isize - y.0 as isize)))
}

/// Returns Some(max(x,y)) is both x and y are Some. Otherwise, returns `None`.
fn opt_max<T: Ord>(a: Option<T>, b: Option<T>) -> Option<T> {
if let (Some(x), Some(y)) = (a, b) {
Some(cmp::max(x, y))
Expand All @@ -1030,7 +1065,12 @@ fn opt_max<T: Ord>(a: Option<T>, b: Option<T>) -> Option<T> {
}
}

// Returns Some(x+y) is both x and y are Some. Otherwise, return none
/// Returns Some(x+y) is both x and y are Some. Otherwise, returns `None`.
fn opt_add(a: Option<usize>, b: Option<usize>) -> Option<usize> {
a.and_then(|x| b.map(|y| x + y))
}

/// Returns Some((x0+y0, x1+y1)) is both x and y are Some. Otherwise, returns `None`.
fn opt_tuple_add(a: Option<(usize, usize)>, b: Option<(usize, usize)>) -> Option<(usize, usize)> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In 28662d2:

Same here. I think this'd be clearer defined close to its use.

a.and_then(|x| b.map(|(w, s)| (w + x.0, s + x.1)))
}