Skip to content

Commit 5fb32c2

Browse files
committed
New MIR optimization pass to reduce branches on match of tuples of enums
1 parent 5e449b9 commit 5fb32c2

22 files changed

+2119
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
use crate::{
2+
transform::{MirPass, MirSource},
3+
util::patch::MirPatch,
4+
};
5+
use rustc_middle::mir::*;
6+
use rustc_middle::ty::{Ty, TyCtxt};
7+
use std::{borrow::Cow, fmt::Debug};
8+
9+
/// This pass optimizes something like
10+
/// ```text
11+
/// let x: Option<()>;
12+
/// let y: Option<()>;
13+
/// match (x,y) {
14+
/// (Some(_), Some(_)) => {0},
15+
/// _ => {1}
16+
/// }
17+
/// ```
18+
/// into something like
19+
/// ```text
20+
/// let x: Option<()>;
21+
/// let y: Option<()>;
22+
/// let discriminant_x = // get discriminant of x
23+
/// let discriminant_y = // get discriminant of x
24+
/// if discriminant_x != discriminant_y {1} else {0}
25+
/// ```
26+
pub struct EarlyOtherwiseBranch;
27+
28+
impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
29+
fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
30+
if tcx.sess.opts.debugging_opts.mir_opt_level < 3 {
31+
return;
32+
}
33+
trace!("running EarlyOtherwiseBranch on {:?}", source);
34+
// we are only interested in this bb if the terminator is a switchInt
35+
let bbs_with_switch =
36+
body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
37+
38+
let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
39+
.flat_map(|(bb_idx, bb)| {
40+
let switch = bb.terminator();
41+
let helper = Helper { body, tcx };
42+
let infos = helper.go(bb, switch)?;
43+
Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
44+
})
45+
.collect();
46+
47+
for opt_to_apply in opts_to_apply {
48+
trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
49+
// create the patch using MirPatch
50+
let mut patch = MirPatch::new(body);
51+
52+
// create temp to store second discriminant in
53+
let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
54+
let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
55+
let temp = patch.new_temp(discr_type, discr_span);
56+
let statements_before =
57+
body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
58+
let end_of_block_location = Location {
59+
block: opt_to_apply.basic_block_first_switch,
60+
statement_index: statements_before,
61+
};
62+
patch.add_statement(end_of_block_location, StatementKind::StorageLive(temp));
63+
64+
// create assignment of discriminant
65+
let place_of_adt_to_get_discriminant_of =
66+
opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
67+
patch.add_assign(
68+
end_of_block_location,
69+
Place::from(temp),
70+
Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
71+
);
72+
73+
// create temp to store NotEqual comparison between the two discriminants
74+
let not_equal = BinOp::Ne;
75+
let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
76+
let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
77+
patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
78+
79+
// create NotEqual comparison between the two discriminants
80+
let first_descriminant_place =
81+
opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
82+
let not_equal_rvalue = Rvalue::BinaryOp(
83+
not_equal,
84+
Operand::Copy(Place::from(temp)),
85+
Operand::Copy(Place::from(first_descriminant_place)),
86+
);
87+
patch.add_statement(
88+
end_of_block_location,
89+
StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)),
90+
);
91+
92+
let (mut targets_to_jump_to, values_to_jump_to): (Vec<_>, Vec<_>) = opt_to_apply
93+
.infos
94+
.iter()
95+
.flat_map(|x| x.second_switch_info.targets_with_values.iter())
96+
.cloned()
97+
.unzip();
98+
99+
// add otherwise case in the end
100+
targets_to_jump_to.push(opt_to_apply.infos[0].first_switch_info.otherwise_bb);
101+
// new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
102+
let new_switch_data = BasicBlockData::new(Some(Terminator {
103+
source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
104+
kind: TerminatorKind::SwitchInt {
105+
// the first and second discriminants are equal, so just pick one
106+
discr: Operand::Copy(first_descriminant_place),
107+
switch_ty: discr_type,
108+
values: Cow::from(values_to_jump_to),
109+
targets: targets_to_jump_to,
110+
},
111+
}));
112+
113+
let new_switch_bb = patch.new_block(new_switch_data);
114+
115+
// switch on the NotEqual. If true, then jump to the `otherwise` case.
116+
// If false, then jump to a basic block that then jumps to the correct disciminant case
117+
let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
118+
let false_case = new_switch_bb;
119+
patch.patch_terminator(
120+
opt_to_apply.basic_block_first_switch,
121+
TerminatorKind::if_(
122+
tcx,
123+
Operand::Move(Place::from(not_equal_temp)),
124+
true_case,
125+
false_case,
126+
),
127+
);
128+
129+
// generate StorageDead for the temp not in use anymore. We use the not_equal_temp in the switch, so we can't mark that dead
130+
patch.add_statement(end_of_block_location, StatementKind::StorageDead(temp));
131+
132+
patch.apply(body);
133+
}
134+
}
135+
}
136+
137+
fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool {
138+
match terminator.kind {
139+
TerminatorKind::SwitchInt { .. } => true,
140+
_ => false,
141+
}
142+
}
143+
144+
struct Helper<'a, 'tcx> {
145+
body: &'a Body<'tcx>,
146+
tcx: TyCtxt<'tcx>,
147+
}
148+
149+
#[derive(Debug, Clone)]
150+
struct SwitchDiscriminantInfo<'tcx> {
151+
/// Type of the discriminant being switched on
152+
discr_ty: Ty<'tcx>,
153+
/// The basic block that the otherwise branch points to
154+
otherwise_bb: BasicBlock,
155+
/// Target along with the value being branched from. Otherwise is not included
156+
targets_with_values: Vec<(BasicBlock, u128)>,
157+
discr_source_info: SourceInfo,
158+
/// The place of the discriminant used in the switch
159+
discr_used_in_switch: Place<'tcx>,
160+
/// The place of the adt that has its discriminant read
161+
place_of_adt_discr_read: Place<'tcx>,
162+
/// The type of the adt that has its discriminant read
163+
type_adt_matched_on: Ty<'tcx>,
164+
}
165+
166+
#[derive(Debug)]
167+
struct OptimizationToApply<'tcx> {
168+
infos: Vec<OptimizationInfo<'tcx>>,
169+
/// Basic block of the original first switch
170+
basic_block_first_switch: BasicBlock,
171+
}
172+
173+
#[derive(Debug)]
174+
struct OptimizationInfo<'tcx> {
175+
/// Info about the first switch and discriminant
176+
first_switch_info: SwitchDiscriminantInfo<'tcx>,
177+
/// Info about the second switch and discriminant
178+
second_switch_info: SwitchDiscriminantInfo<'tcx>,
179+
}
180+
181+
impl<'a, 'tcx> Helper<'a, 'tcx> {
182+
pub fn go(
183+
&self,
184+
bb: &BasicBlockData<'tcx>,
185+
switch: &Terminator<'tcx>,
186+
) -> Option<Vec<OptimizationInfo<'tcx>>> {
187+
// try to find the statement that defines the discriminant that is used for the switch
188+
let discr = self.find_switch_discriminant_info(bb, switch)?;
189+
190+
// go through each target, finding a discriminant read, and a switch
191+
let results = discr.targets_with_values.iter().map(|(target, value)| {
192+
self.find_discriminant_switch_pairing(&discr, target.clone(), value.clone())
193+
});
194+
195+
// if the optimization did not apply for one of the targets, then abort
196+
if results.clone().any(|x| x.is_none()) || results.len() == 0 {
197+
trace!("NO: not all of the targets matched the pattern for optimization");
198+
return None;
199+
}
200+
201+
Some(results.flatten().collect())
202+
}
203+
204+
fn find_discriminant_switch_pairing(
205+
&self,
206+
discr_info: &SwitchDiscriminantInfo<'tcx>,
207+
target: BasicBlock,
208+
value: u128,
209+
) -> Option<OptimizationInfo<'tcx>> {
210+
let bb = &self.body.basic_blocks()[target];
211+
// find switch
212+
let terminator = bb.terminator();
213+
if is_switch(terminator) {
214+
let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
215+
216+
// the types of the two adts matched on have to be equalfor this optimization to apply
217+
if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
218+
trace!(
219+
"NO: types do not match. LHS: {:?}, RHS: {:?}",
220+
discr_info.type_adt_matched_on,
221+
this_bb_discr_info.type_adt_matched_on
222+
);
223+
return None;
224+
}
225+
226+
// the otherwise branch of the two switches have to point to the same bb
227+
if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
228+
trace!("NO: otherwise target is not the same");
229+
return None;
230+
}
231+
232+
// check that the value being matched on is the same. The
233+
if this_bb_discr_info.targets_with_values.iter().find(|x| x.1 == value).is_none() {
234+
trace!("NO: values being matched on are not the same");
235+
return None;
236+
}
237+
238+
// only allow optimization if the left and right of the tuple being matched are the same variants.
239+
// so the following should not optimize
240+
// ```rust
241+
// let x: Option<()>;
242+
// let y: Option<()>;
243+
// match (x,y) {
244+
// (Some(_), None) => {},
245+
// _ => {}
246+
// }
247+
// ```
248+
// We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
249+
if !(this_bb_discr_info.targets_with_values.len() == 1
250+
&& this_bb_discr_info.targets_with_values[0].1 == value)
251+
{
252+
trace!(
253+
"NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
254+
);
255+
return None;
256+
}
257+
258+
// if we reach this point, the optimization applies, and we should be able to optimize this case
259+
// store the info that is needed to apply the optimization
260+
261+
Some(OptimizationInfo {
262+
first_switch_info: discr_info.clone(),
263+
second_switch_info: this_bb_discr_info,
264+
})
265+
} else {
266+
None
267+
}
268+
}
269+
270+
fn find_switch_discriminant_info(
271+
&self,
272+
bb: &BasicBlockData<'tcx>,
273+
switch: &Terminator<'tcx>,
274+
) -> Option<SwitchDiscriminantInfo<'tcx>> {
275+
match &switch.kind {
276+
TerminatorKind::SwitchInt { discr, targets, values, .. } => {
277+
let discr_local = discr.place()?.as_local()?;
278+
// the declaration of the discriminant read. Place of this read is being used in the switch
279+
let discr_decl = &self.body.local_decls()[discr_local];
280+
let discr_ty = discr_decl.ty;
281+
// the otherwise target lies as the last element
282+
let otherwise_bb = targets.get(values.len())?.clone();
283+
let targets_with_values = targets
284+
.iter()
285+
.zip(values.iter())
286+
.map(|(t, v)| (t.clone(), v.clone()))
287+
.collect();
288+
289+
// find the place of the adt where the discriminant is being read from
290+
// assume this is the last statement of the block
291+
let place_of_adt_discr_read = match bb.statements.last()?.kind {
292+
StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
293+
Some(adt_place)
294+
}
295+
_ => None,
296+
}?;
297+
298+
let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
299+
300+
Some(SwitchDiscriminantInfo {
301+
discr_used_in_switch: discr.place()?,
302+
discr_ty,
303+
otherwise_bb,
304+
targets_with_values,
305+
discr_source_info: discr_decl.source_info,
306+
place_of_adt_discr_read,
307+
type_adt_matched_on,
308+
})
309+
}
310+
_ => unreachable!("must only be passed terminator that is a switch"),
311+
}
312+
}
313+
}

compiler/rustc_mir/src/transform/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod copy_prop;
2626
pub mod deaggregator;
2727
pub mod dest_prop;
2828
pub mod dump_mir;
29+
pub mod early_otherwise_branch;
2930
pub mod elaborate_drops;
3031
pub mod generator;
3132
pub mod inline;
@@ -465,6 +466,7 @@ fn run_optimization_passes<'tcx>(
465466
&instcombine::InstCombine,
466467
&const_prop::ConstProp,
467468
&simplify_branches::SimplifyBranches::new("after-const-prop"),
469+
&early_otherwise_branch::EarlyOtherwiseBranch,
468470
&simplify_comparison_integral::SimplifyComparisonIntegral,
469471
&simplify_try::SimplifyArmIdentity,
470472
&simplify_try::SimplifyBranchSame,

0 commit comments

Comments
 (0)