Skip to content

Commit 5fe1466

Browse files
authored
[InstCombine] Simplify switch with selects (#84143)
An example from https://github.com/image-rs/image: ``` define void @test_ult_rhsc(i8 %x) { %val = add nsw i8 %x, -2 %cmp = icmp ult i8 %val, 11 %cond = select i1 %cmp, i8 %val, i8 6 switch i8 %cond, label %bb1 [ i8 0, label %bb2 i8 10, label %bb3 ] bb1: call void @func1() unreachable bb2: call void @func2() unreachable bb3: call void @func3() unreachable } ``` When `%cmp` evaluates to false, we can prove that the range of `%val` is [11, umax]. Thus we can safely replace `%cond` with `%val` since both `switch 6` and `switch %val` go to the default dest `%bb1`. Alive2: https://alive2.llvm.org/ce/z/uSTj6w Godbolt: https://godbolt.org/z/MGrG84bzr This patch will benefit many rust applications and some C/C++ applications (e.g., cvc5).
1 parent 06714e1 commit 5fe1466

File tree

2 files changed

+201
-0
lines changed

2 files changed

+201
-0
lines changed

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3572,6 +3572,38 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
35723572
return nullptr;
35733573
}
35743574

3575+
// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if
3576+
// we can prove that both (switch C) and (switch X) go to the default when cond
3577+
// is false/true.
3578+
static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
3579+
SelectInst *Select,
3580+
bool IsTrueArm) {
3581+
unsigned CstOpIdx = IsTrueArm ? 1 : 2;
3582+
auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
3583+
if (!C)
3584+
return nullptr;
3585+
3586+
BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
3587+
if (CstBB != SI.getDefaultDest())
3588+
return nullptr;
3589+
Value *X = Select->getOperand(3 - CstOpIdx);
3590+
ICmpInst::Predicate Pred;
3591+
const APInt *RHSC;
3592+
if (!match(Select->getCondition(),
3593+
m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
3594+
return nullptr;
3595+
if (IsTrueArm)
3596+
Pred = ICmpInst::getInversePredicate(Pred);
3597+
3598+
// See whether we can replace the select with X
3599+
ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
3600+
for (auto Case : SI.cases())
3601+
if (!CR.contains(Case.getCaseValue()->getValue()))
3602+
return nullptr;
3603+
3604+
return X;
3605+
}
3606+
35753607
Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
35763608
Value *Cond = SI.getCondition();
35773609
Value *Op0;
@@ -3645,6 +3677,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
36453677
}
36463678
}
36473679

3680+
// Fold switch(select cond, X, Y) into switch(X/Y) if possible
3681+
if (auto *Select = dyn_cast<SelectInst>(Cond)) {
3682+
if (Value *V =
3683+
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true))
3684+
return replaceOperand(SI, 0, V);
3685+
if (Value *V =
3686+
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false))
3687+
return replaceOperand(SI, 0, V);
3688+
}
3689+
36483690
KnownBits Known = computeKnownBits(Cond, 0, &SI);
36493691
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
36503692
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
4+
define void @test_ult_rhsc(i8 %x) {
5+
; CHECK-LABEL: define void @test_ult_rhsc(
6+
; CHECK-SAME: i8 [[X:%.*]]) {
7+
; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
8+
; CHECK-NEXT: i8 2, label [[BB2:%.*]]
9+
; CHECK-NEXT: i8 12, label [[BB3:%.*]]
10+
; CHECK-NEXT: ]
11+
; CHECK: bb1:
12+
; CHECK-NEXT: call void @func1()
13+
; CHECK-NEXT: unreachable
14+
; CHECK: bb2:
15+
; CHECK-NEXT: call void @func2()
16+
; CHECK-NEXT: unreachable
17+
; CHECK: bb3:
18+
; CHECK-NEXT: call void @func3()
19+
; CHECK-NEXT: unreachable
20+
;
21+
%val = add nsw i8 %x, -2
22+
%cmp = icmp ult i8 %val, 11
23+
%cond = select i1 %cmp, i8 %val, i8 6
24+
switch i8 %cond, label %bb1 [
25+
i8 0, label %bb2
26+
i8 10, label %bb3
27+
]
28+
29+
bb1:
30+
call void @func1()
31+
unreachable
32+
bb2:
33+
call void @func2()
34+
unreachable
35+
bb3:
36+
call void @func3()
37+
unreachable
38+
}
39+
40+
define void @test_eq_lhsc(i8 %x) {
41+
; CHECK-LABEL: define void @test_eq_lhsc(
42+
; CHECK-SAME: i8 [[X:%.*]]) {
43+
; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
44+
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
45+
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
46+
; CHECK-NEXT: ]
47+
; CHECK: bb1:
48+
; CHECK-NEXT: call void @func1()
49+
; CHECK-NEXT: unreachable
50+
; CHECK: bb2:
51+
; CHECK-NEXT: call void @func2()
52+
; CHECK-NEXT: unreachable
53+
; CHECK: bb3:
54+
; CHECK-NEXT: call void @func3()
55+
; CHECK-NEXT: unreachable
56+
;
57+
%cmp = icmp eq i8 %x, 4
58+
%cond = select i1 %cmp, i8 6, i8 %x
59+
switch i8 %cond, label %bb1 [
60+
i8 0, label %bb2
61+
i8 10, label %bb3
62+
]
63+
64+
bb1:
65+
call void @func1()
66+
unreachable
67+
bb2:
68+
call void @func2()
69+
unreachable
70+
bb3:
71+
call void @func3()
72+
unreachable
73+
}
74+
75+
define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
76+
; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
77+
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
78+
; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
79+
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y]], 11
80+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
81+
; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
82+
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
83+
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
84+
; CHECK-NEXT: i8 13, label [[BB3]]
85+
; CHECK-NEXT: ]
86+
; CHECK: bb1:
87+
; CHECK-NEXT: call void @func1()
88+
; CHECK-NEXT: unreachable
89+
; CHECK: bb2:
90+
; CHECK-NEXT: call void @func2()
91+
; CHECK-NEXT: unreachable
92+
; CHECK: bb3:
93+
; CHECK-NEXT: call void @func3()
94+
; CHECK-NEXT: unreachable
95+
;
96+
%val = add nsw i8 %x, -2
97+
%cmp = icmp ult i8 %y, 11
98+
%cond = select i1 %cmp, i8 %val, i8 6
99+
switch i8 %cond, label %bb1 [
100+
i8 0, label %bb2
101+
i8 10, label %bb3
102+
i8 13, label %bb3
103+
]
104+
105+
bb1:
106+
call void @func1()
107+
unreachable
108+
bb2:
109+
call void @func2()
110+
unreachable
111+
bb3:
112+
call void @func3()
113+
unreachable
114+
}
115+
116+
define void @test_ult_rhsc_fail(i8 %x) {
117+
; CHECK-LABEL: define void @test_ult_rhsc_fail(
118+
; CHECK-SAME: i8 [[X:%.*]]) {
119+
; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
120+
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
121+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
122+
; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
123+
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
124+
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
125+
; CHECK-NEXT: i8 13, label [[BB3]]
126+
; CHECK-NEXT: ]
127+
; CHECK: bb1:
128+
; CHECK-NEXT: call void @func1()
129+
; CHECK-NEXT: unreachable
130+
; CHECK: bb2:
131+
; CHECK-NEXT: call void @func2()
132+
; CHECK-NEXT: unreachable
133+
; CHECK: bb3:
134+
; CHECK-NEXT: call void @func3()
135+
; CHECK-NEXT: unreachable
136+
;
137+
%val = add nsw i8 %x, -2
138+
%cmp = icmp ult i8 %val, 11
139+
%cond = select i1 %cmp, i8 %val, i8 6
140+
switch i8 %cond, label %bb1 [
141+
i8 0, label %bb2
142+
i8 10, label %bb3
143+
i8 13, label %bb3
144+
]
145+
146+
bb1:
147+
call void @func1()
148+
unreachable
149+
bb2:
150+
call void @func2()
151+
unreachable
152+
bb3:
153+
call void @func3()
154+
unreachable
155+
}
156+
157+
declare void @func1()
158+
declare void @func2()
159+
declare void @func3()

0 commit comments

Comments
 (0)