Skip to content

[InstCombine] Simplify switch with selects #84143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 15, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Mar 6, 2024

An example from https://github.com/image-rs/image:

define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}

When %cmp evaluates to false, we can prove that the range of %val is [11, umax]. Thus we can safely replace %cond with %val since both switch 6 and switch %val go to the default dest %bb1.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w
Godbolt: https://godbolt.org/z/MGrG84bzr

This patch will benefit many rust applications and some C/C++ applications (e.g., cvc5).

@dtcxzyw dtcxzyw requested review from dianqk and goldsteinn March 6, 2024 08:59
@dtcxzyw dtcxzyw requested a review from nikic as a code owner March 6, 2024 08:59
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

An example from https://github.com/image-rs/image:

define void @<!-- -->test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @<!-- -->func1()
  unreachable
bb2:
  call void @<!-- -->func2()
  unreachable
bb3:
  call void @<!-- -->func3()
  unreachable
}

When %cmp evaluates to false, we can prove that the range of %val is [11, umax]. Thus we can safely replace %cond with %val since both switch 6 and switch %val go to the default dest %bb1.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w

This patch will benefit many rust applications and some C/C++ applications (e.g., cvc5).


Full diff: https://github.com/llvm/llvm-project/pull/84143.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+41)
  • (added) llvm/test/Transforms/InstCombine/switch-select.ll (+159)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 80ce0c9275b2cb..3d231f9fc7ff4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3334,6 +3334,37 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
+// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if
+// we can prove that both (switch C) and (switch X) go to the default when cond
+// is false/true.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+                                                SelectInst *Select,
+                                                unsigned CstOpIdx) {
+  auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+  if (!C)
+    return nullptr;
+
+  BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+  if (CstBB != SI.getDefaultDest())
+    return nullptr;
+  Value *X = Select->getOperand(3 - CstOpIdx);
+  ICmpInst::Predicate Pred;
+  const APInt *RHSC;
+  if (!match(Select->getCondition(),
+             m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+    return nullptr;
+  if (CstOpIdx == 1)
+    Pred = ICmpInst::getInversePredicate(Pred);
+
+  // See whether we can replace the select with X
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+  for (auto Case : SI.cases())
+    if (!CR.contains(Case.getCaseValue()->getValue()))
+      return nullptr;
+
+  return X;
+}
+
 Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
   Value *Op0;
@@ -3407,6 +3438,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     }
   }
 
+  // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+  if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+      return replaceOperand(SI, 0, V);
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+      return replaceOperand(SI, 0, V);
+  }
+
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..60757c5d22527f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 2, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 12, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %cmp = icmp eq i8 %x, 4
+  %cond = select i1 %cmp, i8 6, i8 %x
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %y, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 6, 2024
ICmpInst::Predicate Pred;
const APInt *RHSC;
if (!match(Select->getCondition(),
m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to only handle RHSC?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can improve this by using KnownBits/ConstantRange in the future. But I guess this patch covers most of the cases in rust applications.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think you could also technically generalize to RHSC (or Other arm really) implying any single case and not condition implying the same condition for the arm, although not sure if that would practically ever come up. Think we would just simplify the select.

@dtcxzyw dtcxzyw force-pushed the perf/switch-on-select branch from f782c5e to 1828c72 Compare March 22, 2024 07:14
Copy link
Member

@dianqk dianqk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Apr 11, 2024

Ping.

@goldsteinn
Copy link
Contributor

LGTM to me as well.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 5fe1466 into llvm:main Apr 15, 2024
@dtcxzyw dtcxzyw deleted the perf/switch-on-select branch April 15, 2024 08:40
bazuzi pushed a commit to bazuzi/llvm-project that referenced this pull request Apr 15, 2024
An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}
```

When `%cmp` evaluates to false, we can prove that the range of `%val` is
[11, umax]. Thus we can safely replace `%cond` with `%val` since both
`switch 6` and `switch %val` go to the default dest `%bb1`.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w
Godbolt: https://godbolt.org/z/MGrG84bzr

This patch will benefit many rust applications and some C/C++
applications (e.g., cvc5).
aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 15, 2024
An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}
```

When `%cmp` evaluates to false, we can prove that the range of `%val` is
[11, umax]. Thus we can safely replace `%cond` with `%val` since both
`switch 6` and `switch %val` go to the default dest `%bb1`.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w
Godbolt: https://godbolt.org/z/MGrG84bzr

This patch will benefit many rust applications and some C/C++
applications (e.g., cvc5).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants