Skip to content

Commit 7008cce

Browse files
committed
New MIR opt pass simplify_pow_of_two
1 parent 89acdae commit 7008cce

20 files changed

+820
-0
lines changed

compiler/rustc_middle/src/mir/interpret/value.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ impl<'tcx> ConstValue<'tcx> {
9393
ConstValue::Scalar(Scalar::from_bool(b))
9494
}
9595

96+
pub fn from_u32(i: u32) -> Self {
97+
ConstValue::Scalar(Scalar::from_u32(i))
98+
}
99+
96100
pub fn from_u64(i: u64) -> Self {
97101
ConstValue::Scalar(Scalar::from_u64(i))
98102
}

compiler/rustc_middle/src/ty/sty.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,10 @@ impl<'tcx> Region<'tcx> {
18721872

18731873
/// Constructors for `Ty`
18741874
impl<'tcx> Ty<'tcx> {
1875+
pub fn new_bool(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
1876+
Ty::new(tcx, TyKind::Bool)
1877+
}
1878+
18751879
// Avoid this in favour of more specific `new_*` methods, where possible.
18761880
#[allow(rustc::usage_of_ty_tykind)]
18771881
#[inline]

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ mod required_consts;
9595
mod reveal_all;
9696
mod separate_const_switch;
9797
mod shim;
98+
mod simplify_pow_of_two;
9899
mod ssa;
99100
// This pass is public to allow external drivers to perform MIR cleanup
100101
mod check_alignment;
@@ -546,6 +547,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
546547
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
547548
&unreachable_prop::UnreachablePropagation,
548549
&uninhabited_enum_branching::UninhabitedEnumBranching,
550+
&simplify_pow_of_two::SimplifyPowOfTwo,
549551
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
550552
&inline::Inline,
551553
&remove_storage_markers::RemoveStorageMarkers,
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//! A pass that checks for and simplifies calls to `pow` where the receiver is a power of
2+
//! two. This can be done with `<<` instead.
3+
4+
use crate::MirPass;
5+
use rustc_const_eval::interpret::{ConstValue, Scalar};
6+
use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData};
7+
use rustc_middle::mir::patch::MirPatch;
8+
use rustc_middle::mir::*;
9+
use rustc_middle::ty::{self, Ty, TyCtxt, UintTy};
10+
use rustc_span::sym;
11+
use rustc_target::abi::FieldIdx;
12+
13+
pub struct SimplifyPowOfTwo;
14+
15+
impl<'tcx> MirPass<'tcx> for SimplifyPowOfTwo {
16+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
17+
let mut patch = MirPatch::new(body);
18+
19+
for (i, bb) in body.basic_blocks.iter_enumerated() {
20+
let term = bb.terminator();
21+
let source_info = term.source_info;
22+
let span = source_info.span;
23+
24+
if let TerminatorKind::Call {
25+
func,
26+
args,
27+
destination,
28+
target: Some(target),
29+
unwind,
30+
call_source: CallSource::Normal,
31+
..
32+
} = &term.kind
33+
&& let Some(def_id) = func.const_fn_def().map(|def| def.0)
34+
&& let def_path = tcx.def_path(def_id)
35+
&& tcx.crate_name(def_path.krate) == sym::core
36+
// FIXME(Centri3): I feel like we should do this differently...
37+
&& let [
38+
DisambiguatedDefPathData { data: DefPathData::TypeNs(sym::num), disambiguator: 0 },
39+
DisambiguatedDefPathData { data: DefPathData::Impl, .. },
40+
DisambiguatedDefPathData { data: DefPathData::ValueNs(sym::pow), .. },
41+
] = &*def_path.data
42+
&& let [recv, exp] = args.as_slice()
43+
&& let Some(recv_const) = recv.constant()
44+
&& let ConstantKind::Val(
45+
ConstValue::Scalar(Scalar::Int(recv_int)),
46+
recv_ty,
47+
) = recv_const.literal
48+
&& let Ok(recv_val) = match recv_ty.kind() {
49+
ty::Int(_) => {
50+
let result = recv_int.try_to_int(recv_int.size()).unwrap_or(-1).max(0);
51+
if result > 0 {
52+
Ok(result as u128)
53+
} else {
54+
continue;
55+
}
56+
},
57+
ty::Uint(_) => recv_int.try_to_uint(recv_int.size()),
58+
_ => continue,
59+
}
60+
&& let power_used = f32::log2(recv_val as f32)
61+
// Precision loss means it's not a power of two
62+
&& power_used == (power_used as u32) as f32
63+
// `0` would be `1.pow()`, which we shouldn't try to optimize as it's
64+
// already entirely optimized away
65+
&& power_used != 0.0
66+
// Same here
67+
&& recv_val != 0
68+
{
69+
let power_used = power_used as u32;
70+
let loc = Location { block: i, statement_index: bb.statements.len() };
71+
let exp_ty = Ty::new(tcx, ty::Uint(UintTy::U32));
72+
let checked_mul =
73+
patch.new_temp(Ty::new_tup(tcx, &[exp_ty, Ty::new_bool(tcx)]), span);
74+
75+
// If this is not `2.pow(...)`, we need to multiply the number of times we
76+
// shift the bits left by the receiver's power of two used, e.g.:
77+
//
78+
// > 2 -> 1
79+
// > 4 -> 2
80+
// > 16 -> 4
81+
// > 256 -> 8
82+
//
83+
// If this is `1`, then we *could* remove this entirely but it'll be
84+
// optimized out anyway by later passes (or perhaps LLVM) so it's entirely
85+
// unnecessary to do so.
86+
patch.add_assign(
87+
loc,
88+
checked_mul.into(),
89+
Rvalue::CheckedBinaryOp(
90+
BinOp::Mul,
91+
Box::new((
92+
exp.clone(),
93+
Operand::Constant(Box::new(Constant {
94+
span,
95+
user_ty: None,
96+
literal: ConstantKind::Val(
97+
ConstValue::from_u32(power_used),
98+
exp_ty,
99+
),
100+
})),
101+
)),
102+
),
103+
);
104+
105+
let num_shl = tcx.mk_place_field(checked_mul.into(), FieldIdx::from_u32(0), exp_ty);
106+
let mul_result =
107+
tcx.mk_place_field(checked_mul.into(), FieldIdx::from_u32(1), Ty::new_bool(tcx));
108+
let shl_result = patch.new_temp(Ty::new_bool(tcx), span);
109+
110+
// Whether the shl will overflow, if so we return 0
111+
patch.add_assign(
112+
loc,
113+
shl_result.into(),
114+
Rvalue::BinaryOp(
115+
BinOp::Lt,
116+
Box::new((
117+
Operand::Copy(num_shl),
118+
Operand::Constant(Box::new(Constant {
119+
span,
120+
user_ty: None,
121+
literal: ConstantKind::Val(ConstValue::from_u32(32), exp_ty),
122+
})),
123+
)),
124+
),
125+
);
126+
127+
let should_be_zero_bool = patch.new_temp(Ty::new_bool(tcx), span);
128+
let should_be_zero = patch.new_temp(recv_ty, span);
129+
130+
patch.add_assign(
131+
loc,
132+
should_be_zero_bool.into(),
133+
Rvalue::BinaryOp(
134+
BinOp::BitOr,
135+
Box::new((
136+
Operand::Copy(mul_result.into()),
137+
Operand::Copy(shl_result.into()),
138+
)),
139+
),
140+
);
141+
142+
patch.add_assign(
143+
loc,
144+
should_be_zero.into(),
145+
Rvalue::Cast(
146+
CastKind::IntToInt,
147+
Operand::Copy(should_be_zero_bool.into()),
148+
recv_ty,
149+
),
150+
);
151+
152+
let shl_exp_ty = patch.new_temp(exp_ty, span);
153+
let shl = patch.new_temp(recv_ty, span);
154+
155+
patch.add_assign(
156+
loc,
157+
shl_exp_ty.into(),
158+
Rvalue::BinaryOp(
159+
BinOp::Shl,
160+
Box::new((
161+
Operand::Constant(Box::new(Constant {
162+
span,
163+
user_ty: None,
164+
literal: ConstantKind::Val(ConstValue::from_u32(1), exp_ty),
165+
})),
166+
Operand::Copy(num_shl.into()),
167+
)),
168+
),
169+
);
170+
171+
patch.add_assign(
172+
loc,
173+
shl.into(),
174+
Rvalue::Cast(
175+
CastKind::IntToInt,
176+
Operand::Copy(shl_exp_ty.into()),
177+
recv_ty,
178+
),
179+
);
180+
181+
patch.add_assign(
182+
loc,
183+
*destination,
184+
Rvalue::BinaryOp(
185+
BinOp::MulUnchecked,
186+
Box::new((Operand::Copy(shl.into()), Operand::Copy(should_be_zero.into()))),
187+
),
188+
);
189+
190+
// shl doesn't set the overflow flag on x86_64 or even in Rust, so shr to
191+
// see if it overflowed. If it equals 1, it did not, but we also need to
192+
// check `shl_result` to ensure that if this is a multiple of the type's
193+
// size it won't wrap back over to 1
194+
//
195+
// FIXME(Centri3): Do we use `debug_assertions` or `overflow_checks` here?
196+
if tcx.sess.opts.debug_assertions {
197+
let shr = patch.new_temp(recv_ty, span);
198+
let shl_eq_shr = patch.new_temp(Ty::new_bool(tcx), span);
199+
let overflowed = patch.new_temp(Ty::new_bool(tcx), span);
200+
201+
patch.add_assign(
202+
loc,
203+
shr.into(),
204+
Rvalue::BinaryOp(
205+
BinOp::Shr,
206+
Box::new((Operand::Copy(shl.into()), Operand::Copy(num_shl.into()))),
207+
),
208+
);
209+
210+
patch.add_assign(
211+
loc,
212+
shl_eq_shr.into(),
213+
Rvalue::BinaryOp(
214+
BinOp::Eq,
215+
Box::new((Operand::Copy(shl.into()), Operand::Copy(shr.into()))),
216+
),
217+
);
218+
219+
patch.add_assign(
220+
loc,
221+
overflowed.into(),
222+
Rvalue::BinaryOp(
223+
BinOp::BitAnd,
224+
Box::new((Operand::Copy(shl_eq_shr.into()), Operand::Copy(shl_result.into()))),
225+
),
226+
);
227+
228+
patch.patch_terminator(
229+
i,
230+
TerminatorKind::Assert {
231+
cond: Operand::Copy(overflowed.into()),
232+
expected: true,
233+
msg: Box::new(AssertMessage::Overflow(
234+
// For consistency with the previous error message, though
235+
// it's technically incorrect
236+
BinOp::Mul,
237+
Operand::Constant(Box::new(Constant {
238+
span,
239+
user_ty: None,
240+
literal: ConstantKind::Val(ConstValue::Scalar(Scalar::from_u32(1)), exp_ty),
241+
})),
242+
Operand::Copy(num_shl.into()),
243+
)),
244+
target: *target,
245+
unwind: *unwind,
246+
},
247+
);
248+
} else {
249+
patch.patch_terminator(i, TerminatorKind::Goto { target: *target });
250+
}
251+
}
252+
}
253+
254+
patch.apply(body);
255+
}
256+
}

compiler/rustc_span/src/symbol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,7 @@ symbols! {
10541054
not,
10551055
notable_trait,
10561056
note,
1057+
num,
10571058
object_safe_for_dispatch,
10581059
of,
10591060
offset,
@@ -1121,6 +1122,7 @@ symbols! {
11211122
poll,
11221123
position,
11231124
post_dash_lto: "post-lto",
1125+
pow,
11241126
powerpc_target_feature,
11251127
powf32,
11261128
powf64,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// compile-flags: -Coverflow-checks=false
2+
3+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.diff
4+
fn slow_2_u(a: u32) -> u32 {
5+
2u32.pow(a)
6+
}
7+
8+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.diff
9+
fn slow_2_i(a: u32) -> i32 {
10+
2i32.pow(a)
11+
}
12+
13+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.diff
14+
fn slow_4_u(a: u32) -> u32 {
15+
4u32.pow(a)
16+
}
17+
18+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.diff
19+
fn slow_4_i(a: u32) -> i32 {
20+
4i32.pow(a)
21+
}
22+
23+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.diff
24+
fn slow_256_u(a: u32) -> u32 {
25+
256u32.pow(a)
26+
}
27+
28+
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.diff
29+
fn slow_256_i(a: u32) -> i32 {
30+
256i32.pow(a)
31+
}
32+
33+
fn main() {
34+
slow_2_u(0);
35+
slow_2_i(0);
36+
slow_2_u(1);
37+
slow_2_i(1);
38+
slow_2_u(2);
39+
slow_2_i(2);
40+
slow_4_u(4);
41+
slow_4_i(4);
42+
slow_4_u(15);
43+
slow_4_i(15);
44+
slow_4_u(16);
45+
slow_4_i(16);
46+
slow_4_u(17);
47+
slow_4_i(17);
48+
slow_256_u(2);
49+
slow_256_i(2);
50+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
- // MIR for `slow_256_i` before SimplifyPowOfTwo
2+
+ // MIR for `slow_256_i` after SimplifyPowOfTwo
3+
4+
fn slow_256_i(_1: u32) -> i32 {
5+
debug a => _1;
6+
let mut _0: i32;
7+
let mut _2: u32;
8+
+ let mut _3: (u32, bool);
9+
+ let mut _4: bool;
10+
+ let mut _5: bool;
11+
+ let mut _6: i32;
12+
+ let mut _7: u32;
13+
+ let mut _8: i32;
14+
15+
bb0: {
16+
StorageLive(_2);
17+
_2 = _1;
18+
- _0 = core::num::<impl i32>::pow(const 256_i32, move _2) -> [return: bb1, unwind unreachable];
19+
+ _3 = CheckedMul(move _2, const 8_u32);
20+
+ _4 = Lt((_3.0: u32), const 32_u32);
21+
+ _5 = BitOr((_3.1: bool), _4);
22+
+ _6 = _5 as i32 (IntToInt);
23+
+ _7 = Shl(const 1_u32, (_3.0: u32));
24+
+ _8 = _7 as i32 (IntToInt);
25+
+ _0 = MulUnchecked(_8, _6);
26+
+ goto -> bb1;
27+
}
28+
29+
bb1: {
30+
StorageDead(_2);
31+
return;
32+
}
33+
}
34+

0 commit comments

Comments
 (0)