Skip to content

[InstCombine] Simplify switch with selects #84143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3572,6 +3572,38 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return nullptr;
}

// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if
// we can prove that both (switch C) and (switch X) go to the default when cond
// is false/true.
static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
SelectInst *Select,
bool IsTrueArm) {
unsigned CstOpIdx = IsTrueArm ? 1 : 2;
auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
if (!C)
return nullptr;

BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
if (CstBB != SI.getDefaultDest())
return nullptr;
Value *X = Select->getOperand(3 - CstOpIdx);
ICmpInst::Predicate Pred;
const APInt *RHSC;
if (!match(Select->getCondition(),
m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to only handle RHSC?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can improve this by using KnownBits/ConstantRange in the future. But I guess this patch covers most of the cases in rust applications.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think you could also technically generalize to RHSC (or Other arm really) implying any single case and not condition implying the same condition for the arm, although not sure if that would practically ever come up. Think we would just simplify the select.

return nullptr;
if (IsTrueArm)
Pred = ICmpInst::getInversePredicate(Pred);

// See whether we can replace the select with X
ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
for (auto Case : SI.cases())
if (!CR.contains(Case.getCaseValue()->getValue()))
return nullptr;

return X;
}

Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
Expand Down Expand Up @@ -3645,6 +3677,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
}
}

// Fold switch(select cond, X, Y) into switch(X/Y) if possible
if (auto *Select = dyn_cast<SelectInst>(Cond)) {
if (Value *V =
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true))
return replaceOperand(SI, 0, V);
if (Value *V =
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false))
return replaceOperand(SI, 0, V);
}

KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
Expand Down
159 changes: 159 additions & 0 deletions llvm/test/Transforms/InstCombine/switch-select.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define void @test_ult_rhsc(i8 %x) {
; CHECK-LABEL: define void @test_ult_rhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 2, label [[BB2:%.*]]
; CHECK-NEXT: i8 12, label [[BB3:%.*]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
; CHECK-NEXT: unreachable
; CHECK: bb2:
; CHECK-NEXT: call void @func2()
; CHECK-NEXT: unreachable
; CHECK: bb3:
; CHECK-NEXT: call void @func3()
; CHECK-NEXT: unreachable
;
%val = add nsw i8 %x, -2
%cmp = icmp ult i8 %val, 11
%cond = select i1 %cmp, i8 %val, i8 6
switch i8 %cond, label %bb1 [
i8 0, label %bb2
i8 10, label %bb3
]

bb1:
call void @func1()
unreachable
bb2:
call void @func2()
unreachable
bb3:
call void @func3()
unreachable
}

define void @test_eq_lhsc(i8 %x) {
; CHECK-LABEL: define void @test_eq_lhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
; CHECK-NEXT: unreachable
; CHECK: bb2:
; CHECK-NEXT: call void @func2()
; CHECK-NEXT: unreachable
; CHECK: bb3:
; CHECK-NEXT: call void @func3()
; CHECK-NEXT: unreachable
;
%cmp = icmp eq i8 %x, 4
%cond = select i1 %cmp, i8 6, i8 %x
switch i8 %cond, label %bb1 [
i8 0, label %bb2
i8 10, label %bb3
]

bb1:
call void @func1()
unreachable
bb2:
call void @func2()
unreachable
bb3:
call void @func3()
unreachable
}

define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y]], 11
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
; CHECK-NEXT: i8 13, label [[BB3]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
; CHECK-NEXT: unreachable
; CHECK: bb2:
; CHECK-NEXT: call void @func2()
; CHECK-NEXT: unreachable
; CHECK: bb3:
; CHECK-NEXT: call void @func3()
; CHECK-NEXT: unreachable
;
%val = add nsw i8 %x, -2
%cmp = icmp ult i8 %y, 11
%cond = select i1 %cmp, i8 %val, i8 6
switch i8 %cond, label %bb1 [
i8 0, label %bb2
i8 10, label %bb3
i8 13, label %bb3
]

bb1:
call void @func1()
unreachable
bb2:
call void @func2()
unreachable
bb3:
call void @func3()
unreachable
}

define void @test_ult_rhsc_fail(i8 %x) {
; CHECK-LABEL: define void @test_ult_rhsc_fail(
; CHECK-SAME: i8 [[X:%.*]]) {
; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
; CHECK-NEXT: i8 13, label [[BB3]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
; CHECK-NEXT: unreachable
; CHECK: bb2:
; CHECK-NEXT: call void @func2()
; CHECK-NEXT: unreachable
; CHECK: bb3:
; CHECK-NEXT: call void @func3()
; CHECK-NEXT: unreachable
;
%val = add nsw i8 %x, -2
%cmp = icmp ult i8 %val, 11
%cond = select i1 %cmp, i8 %val, i8 6
switch i8 %cond, label %bb1 [
i8 0, label %bb2
i8 10, label %bb3
i8 13, label %bb3
]

bb1:
call void @func1()
unreachable
bb2:
call void @func2()
unreachable
bb3:
call void @func3()
unreachable
}

declare void @func1()
declare void @func2()
declare void @func3()