[InstCombine] fold Select with a predicate consists of Icmp connected by And #76363

Closed
wants to merge 2 commits into from

Conversation

@sun-jacobi sun-jacobi commented Dec 25, 2023

This patch closes #76043.


We extend the pre-existing foldSelectWithBinaryOp to support the cases below:

%A = icmp eq %TV, %FV
%C = and %A, %B
%D = select %C, %TV, %FV
->
%FV

or

%A = icmp ne %TV, %FV
%C = or %A, %B
%D = select %C, %TV, %FV
->
%TV

The Alive2 proof: https://alive2.llvm.org/ce/z/XLyhE-


For the updated test cases in select-and-cmp.ll and select-or-cmp.ll, we also provide an Alive2 proof: https://alive2.llvm.org/ce/z/krhtZy
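
As an informal complement to the Alive2 proofs, the two folds can be sanity-checked by brute force on plain integers. The sketch below is not LLVM code: `selectValue` and `foldsHoldExhaustively` are hypothetical helper names, values are restricted to 4 bits, and poison/undef are ignored (Alive2 covers those). The or case is modelled with the true value in the first select arm, matching what the patch returns (`TrueVal` for `Or`).

```cpp
// Models `select i1 c, tv, fv` on plain unsigned integers (hypothetical helper).
constexpr unsigned selectValue(bool c, unsigned tv, unsigned fv) {
  return c ? tv : fv;
}

// Checks both folds for every pair of 4-bit values and both values of %B.
constexpr bool foldsHoldExhaustively() {
  for (unsigned tv = 0; tv < 16; ++tv) {
    for (unsigned fv = 0; fv < 16; ++fv) {
      for (int b = 0; b < 2; ++b) {
        // and case: select ((tv == fv) & b), tv, fv  ==>  fv
        const bool andCond = (tv == fv) && (b != 0);
        if (selectValue(andCond, tv, fv) != fv)
          return false;
        // or case: select ((tv != fv) | b), tv, fv  ==>  tv
        const bool orCond = (tv != fv) || (b != 0);
        if (selectValue(orCond, tv, fv) != tv)
          return false;
      }
    }
  }
  return true;
}

// Evaluated at compile time (needs C++14 or later for loops in constexpr).
static_assert(foldsHoldExhaustively(), "both select folds should hold");
```

The reasoning behind the check: in the and case the condition can only be true when `tv == fv`, so either arm equals `fv`; in the or case the condition can only be false when `tv == fv`, so either arm equals `tv`.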

@sun-jacobi sun-jacobi requested a review from nikic as a code owner December 25, 2023 16:25
@llvmbot llvmbot added llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Dec 25, 2023
llvmbot commented Dec 25, 2023

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Chia (sun-jacobi)


Full diff: https://github.com/llvm/llvm-project/pull/76363.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+46-27)
  • (modified) llvm/test/Transforms/InstSimplify/select-and-cmp.ll (+27-24)
  • (modified) llvm/test/Transforms/InstSimplify/select-or-cmp.ll (+28-36)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 5beac5547d65e0..ca2fc9ca173c9a 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -79,15 +79,56 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
                                               const SimplifyQuery &SQ,
                                               unsigned MaxRecurse);
 
+static Value *simplifySelectWithAndOR(Value *Cond, Value *TrueVal,
+                                      Value *FalseVal,
+                                      CmpInst::Predicate ExpectedPred,
+                                      BinaryOperator::BinaryOps BinOpCode,
+                                      unsigned int MaxRecurse) {
+  assert(
+      (BinOpCode == BinaryOperator::And || BinOpCode == BinaryOperator::Or) &&
+      "Binary Operator should be And or Or");
+
+  assert(
+      (BinOpCode == BinaryOperator::And && ExpectedPred == ICmpInst::ICMP_EQ) ||
+      (BinOpCode == BinaryOperator::Or && ExpectedPred == ICmpInst::ICMP_NE));
+
+  if (!MaxRecurse)
+    return nullptr;
+
+  auto getSimplifiedValue = [](BinaryOperator::BinaryOps BinOpCode,
+                               Value *TrueVal, Value *FalseVal) {
+    return BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
+  };
+
+  CmpInst::Predicate Pred;
+  if (match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
+      Pred == ExpectedPred)
+    return getSimplifiedValue(BinOpCode, TrueVal, FalseVal);
+
+  Value *X, *Y;
+  if (match(Cond, m_c_BinOp(BinOpCode, m_Value(X), m_Value(Y)))) {
+
+    auto matchBinOpCode = [&](Value *V) {
+      return simplifySelectWithAndOR(V, TrueVal, FalseVal, ExpectedPred,
+                                     BinOpCode, MaxRecurse - 1);
+    };
+
+    if (matchBinOpCode(X) || matchBinOpCode(Y))
+      return getSimplifiedValue(BinOpCode, TrueVal, FalseVal);
+  }
+
+  return nullptr;
+}
+
 static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
-                                     Value *FalseVal) {
+                                     Value *FalseVal, unsigned int MaxRecurse) {
   BinaryOperator::BinaryOps BinOpCode;
   if (auto *BO = dyn_cast<BinaryOperator>(Cond))
     BinOpCode = BO->getOpcode();
   else
     return nullptr;
 
-  CmpInst::Predicate ExpectedPred, Pred1, Pred2;
+  CmpInst::Predicate ExpectedPred;
   if (BinOpCode == BinaryOperator::Or) {
     ExpectedPred = ICmpInst::ICMP_NE;
   } else if (BinOpCode == BinaryOperator::And) {
@@ -95,30 +136,8 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
   } else
     return nullptr;
 
-  // %A = icmp eq %TV, %FV
-  // %B = icmp eq %X, %Y (and one of these is a select operand)
-  // %C = and %A, %B
-  // %D = select %C, %TV, %FV
-  // -->
-  // %FV
-
-  // %A = icmp ne %TV, %FV
-  // %B = icmp ne %X, %Y (and one of these is a select operand)
-  // %C = or %A, %B
-  // %D = select %C, %TV, %FV
-  // -->
-  // %TV
-  Value *X, *Y;
-  if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal),
-                                      m_Specific(FalseVal)),
-                             m_ICmp(Pred2, m_Value(X), m_Value(Y)))) ||
-      Pred1 != Pred2 || Pred1 != ExpectedPred)
-    return nullptr;
-
-  if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
-    return BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
-
-  return nullptr;
+  return simplifySelectWithAndOR(Cond, TrueVal, FalseVal, ExpectedPred,
+                                 BinOpCode, MaxRecurse);
 }
 
 /// For a boolean type or a vector of boolean type, return false or a vector
@@ -4906,7 +4925,7 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
   if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q))
     return V;
 
-  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
+  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal, MaxRecurse))
     return V;
 
   std::optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
diff --git a/llvm/test/Transforms/InstSimplify/select-and-cmp.ll b/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
index 41a4ab96bd62cc..8a48618e217084 100644
--- a/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
+++ b/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
@@ -78,6 +78,28 @@ define i32 @select_and_inv_icmp_alt(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
+define i32 @select_and_icmp_ne(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_ne(
+; CHECK-NEXT:    ret i32 [[X:%.*]]
+;
+  %A = icmp eq i32 %x, %z
+  %B = icmp ne i32 %y, %z
+  %C = and i1 %A, %B
+  %D = select i1 %C, i32 %z, i32 %x
+  ret i32 %D
+}
+
+define i32 @select_and_icmp_ne_alt(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_ne_alt(
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
+;
+  %A = icmp eq i32 %x, %z
+  %B = icmp ne i32 %y, %z
+  %C = and i1 %A, %B
+  %D = select i1 %C, i32 %x, i32 %z
+  ret i32 %D
+}
+
 define i32 @select_and_inv_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_inv_icmp(
 ; CHECK-NEXT:    ret i32 [[X:%.*]]
@@ -115,21 +137,6 @@ define i32 @select_and_icmp_inv(i32 %x, i32 %y, i32 %z) {
 ; Negative tests
 define i32 @select_and_icmp_pred_bad_1(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_pred_bad_1(
-; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[Z]], i32 [[X]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp eq i32 %x, %z
-  %B = icmp ne i32 %y, %z
-  %C = and i1 %A, %B
-  %D = select i1 %C, i32 %z, i32 %x
-  ret i32 %D
-}
-
-define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_2(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -143,8 +150,8 @@ define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_3(
+define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_2(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -158,8 +165,8 @@ define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_4(
+define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -235,11 +242,7 @@ define i32 @select_and_icmp_bad_op_2(i32 %x, i32 %y, i32 %z, i32 %k) {
 
 define i32 @select_and_icmp_alt_bad_1(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_alt_bad_1(
-; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[X]], i32 [[Z]]
-; CHECK-NEXT:    ret i32 [[D]]
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
 ;
   %A = icmp eq i32 %x, %z
   %B = icmp ne i32 %y, %z
diff --git a/llvm/test/Transforms/InstSimplify/select-or-cmp.ll b/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
index 0e410a9645f0d2..8b91ea03062695 100644
--- a/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
+++ b/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
@@ -78,6 +78,28 @@ define i32 @select_or_inv_icmp_alt(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
+define i32 @select_or_icmp_eq(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_eq(
+; CHECK-NEXT:    ret i32 [[X:%.*]]
+;
+  %A = icmp ne i32 %x, %z
+  %B = icmp eq i32 %y, %z
+  %C = or i1 %A, %B
+  %D = select i1 %C, i32 %x, i32 %z
+  ret i32 %D
+}
+
+define i32 @select_or_icmp_eq_alt(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_eq_alt(
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
+;
+  %A = icmp ne i32 %x, %z
+  %B = icmp eq i32 %y, %z
+  %C = or i1 %A, %B
+  %D = select i1 %C, i32 %z, i32 %x
+  ret i32 %D
+}
+
 define <2 x i8> @select_or_icmp_alt_vec(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_or_icmp_alt_vec(
 ; CHECK-NEXT:    ret <2 x i8> [[X:%.*]]
@@ -129,21 +151,6 @@ define i32 @select_and_icmp_pred_bad_1(i32 %x, i32 %y, i32 %z) {
 
 define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_pred_bad_2(
-; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[Z]], i32 [[X]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp ne i32 %x, %z
-  %B = icmp eq i32 %y, %z
-  %C = or i1 %A, %B
-  %D = select i1 %C, i32 %z, i32 %x
-  ret i32 %D
-}
-
-define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -157,8 +164,8 @@ define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_4(
+define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -250,21 +257,6 @@ define i32 @select_or_icmp_alt_bad_1(i32 %x, i32 %y, i32 %z) {
 
 define i32 @select_or_icmp_alt_bad_2(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_or_icmp_alt_bad_2(
-; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[X]], i32 [[Z]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp ne i32 %x, %z
-  %B = icmp eq i32 %y, %z
-  %C = or i1 %A, %B
-  %D = select i1 %C, i32 %x, i32 %z
-  ret i32 %D
-}
-
-define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -278,8 +270,8 @@ define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_4(
+define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_alt_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -293,8 +285,8 @@ define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_or_icmp_alt_bad_5(i32 %x, i32 %y, i32 %z, i32 %k) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_5(
+define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z, i32 %k) {
+; CHECK-LABEL: @select_or_icmp_alt_bad_4(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[K:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
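
The recursion in `simplifySelectWithAndOR` may be easier to follow on a toy model. The sketch below is an assumption-laden miniature, not LLVM's API: `Cond`, `Kind`, `Fold`, `simplifySelect` and the `demo*` helpers are hypothetical names, operands are plain integer ids, and the caller passes the binary opcode directly instead of deriving it from the condition as `foldSelectWithBinaryOp` does.

```cpp
// Hypothetical miniature condition tree; not LLVM's Value hierarchy.
enum class Kind { ICmpEq, ICmpNe, And, Or, Opaque };

struct Cond {
  Kind kind;
  int lhs, rhs;       // operand ids, meaningful for ICmpEq / ICmpNe
  const Cond *a, *b;  // children, meaningful for And / Or
};

Cond icmp(Kind k, int lhs, int rhs) { return Cond{k, lhs, rhs, nullptr, nullptr}; }
Cond binop(Kind k, const Cond *a, const Cond *b) { return Cond{k, 0, 0, a, b}; }
Cond opaque() { return Cond{Kind::Opaque, 0, 0, nullptr, nullptr}; }

enum class Fold { None, TrueVal, FalseVal };

// Same shape as the patch: look for `icmp eq tv, fv` under a chain of `and`s
// (or `icmp ne tv, fv` under a chain of `or`s), bounded by maxRecurse.
Fold simplifySelect(const Cond &c, int tv, int fv, Kind binOp,
                    unsigned maxRecurse) {
  if (maxRecurse == 0)
    return Fold::None;
  const Kind expected = (binOp == Kind::And) ? Kind::ICmpEq : Kind::ICmpNe;
  const Fold result = (binOp == Kind::Or) ? Fold::TrueVal : Fold::FalseVal;

  // Base case: the compare matches the select arms in either operand order.
  if (c.kind == expected &&
      ((c.lhs == tv && c.rhs == fv) || (c.lhs == fv && c.rhs == tv)))
    return result;

  // Recursive case: only descend through the same opcode.
  if (c.kind == binOp && c.a && c.b) {
    if (simplifySelect(*c.a, tv, fv, binOp, maxRecurse - 1) != Fold::None ||
        simplifySelect(*c.b, tv, fv, binOp, maxRecurse - 1) != Fold::None)
      return result;
  }
  return Fold::None;
}

// select (other & (other & (1 == 2))), tv=2, fv=1
Fold demoNestedAnd(unsigned maxRecurse) {
  Cond eq = icmp(Kind::ICmpEq, 1, 2);
  Cond other = opaque();
  Cond inner = binop(Kind::And, &other, &eq);
  Cond outer = binop(Kind::And, &other, &inner);
  return simplifySelect(outer, /*tv=*/2, /*fv=*/1, Kind::And, maxRecurse);
}

// select (other & ((1 != 2) | other)), tv=1, fv=2: the `or` blocks the descent.
Fold demoMixedOps() {
  Cond ne = icmp(Kind::ICmpNe, 1, 2);
  Cond other = opaque();
  Cond inner = binop(Kind::Or, &ne, &other);
  Cond outer = binop(Kind::And, &other, &inner);
  return simplifySelect(outer, /*tv=*/1, /*fv=*/2, Kind::And, /*maxRecurse=*/3);
}
```

In this model `demoNestedAnd(3)` finds the compare two levels down and reports `Fold::FalseVal`, while `demoNestedAnd(2)` runs out of recursion budget and reports `Fold::None`. `demoMixedOps()` also reports `Fold::None`, because an `or` nested under an `and` is never entered.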

@sun-jacobi sun-jacobi changed the title [InstCombine] fold Select with a predicate consists of icmp connected by And [InstCombine] fold Select with a predicate consists of Icmp connected by And Dec 25, 2023
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Dec 25, 2023
dtcxzyw added a commit that referenced this pull request Dec 31, 2023
…ldable (#76621)

This patch does the following folds:
```
(select A && B, T, F) -> (select A, (select B, T, F), F)
(select A || B, T, F) -> (select A, T, (select B, T, F))
```
if `(select B, T, F)` can be folded into a value or a canonicalized SPF.
Alive2: https://alive2.llvm.org/ce/z/4Bdrbu
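
A minimal sketch of why the two decompositions above are value-preserving, written as a truth-table check on plain integers. `sel` and `decompositionHolds` are hypothetical helper names, `t` and `f` are arbitrary distinct constants, and poison/undef semantics are deliberately ignored (the Alive2 link covers those).

```cpp
// Models `select c, t, f` on plain integers (hypothetical helper).
constexpr int sel(bool c, int t, int f) { return c ? t : f; }

// Enumerates all four (A, B) combinations and compares both sides of each fold.
constexpr bool decompositionHolds() {
  const int t = 7, f = 3; // arbitrary distinct values for T and F
  for (int a = 0; a < 2; ++a) {
    for (int b = 0; b < 2; ++b) {
      const bool A = (a != 0), B = (b != 0);
      // (select A && B, T, F) == (select A, (select B, T, F), F)
      if (sel(A && B, t, f) != sel(A, sel(B, t, f), f))
        return false;
      // (select A || B, T, F) == (select A, T, (select B, T, F))
      if (sel(A || B, t, f) != sel(A, t, sel(B, t, f)))
        return false;
    }
  }
  return true;
}

// Evaluated at compile time (needs C++14 or later for loops in constexpr).
static_assert(decompositionHolds(), "select decomposition should hold");
```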

The original motivation of this patch is to simplify the following
pattern:
```
%.sroa.speculated.i = tail call i64 @llvm.umax.i64(i64 %sub.ptr.div.i.i, i64 1)
%add.i = add i64 %.sroa.speculated.i, %sub.ptr.div.i.i
%cmp7.i = icmp ult i64 %add.i, %sub.ptr.div.i.i
%cmp9.i = icmp ugt i64 %add.i, 1152921504606846975
%or.cond.i = or i1 %cmp7.i, %cmp9.i
%cond.i = select i1 %or.cond.i, i64 1152921504606846975, i64 %add.i
->
%.sroa.speculated.i = tail call i64 @llvm.umax.i64(i64 %sub.ptr.div.i.i, i64 1)
%add.i = add i64 %.sroa.speculated.i, %sub.ptr.div.i.i
%cmp7.i = icmp ult i64 %add.i, %sub.ptr.div.i.i
%min = call i64 @llvm.umin.i64(i64 %add.i, i64 1152921504606846975)
%cond.i = select i1 %cmp7.i, i64 1152921504606846975, i64 %min
```
The latter form has better codegen on some backends. It is also more
analysis-friendly than the original one.
Godbolt: https://godbolt.org/z/eK6eb5jf1
Alive2: https://alive2.llvm.org/ce/z/VHlxL2

Compile-time impact:
http://llvm-compile-time-tracker.com/compare.php?from=7c71d3996a72b9b024622f23bf556539b961c88c&to=638ce8666fadaca1ab2639a3c2bc52a4a8508f40&stat=instructions:u

|stage1-O3|stage1-ReleaseThinLTO|stage1-ReleaseLTO-g|stage1-O0-g|stage2-O3|stage2-O0-g|stage2-clang|
|--|--|--|--|--|--|--|
|+0.02%|-0.00%|+0.02%|-0.03%|-0.00%|-0.05%|-0.00%|

It is an alternative to #76203 and #76363 because we can simplify
`select (icmp eq/ne a, b), a, b` into `b` or `a`.
Fixes #75784.
Fixes #76043.

Thanks to @XChy for providing additional tests.
Co-authored-by: XChy <[email protected]>
nikic commented Jan 2, 2024

#76043 has been fixed by #76621. Does this patch handle any additional cases that one does not?

Successfully merging this pull request may close these issues.

missed optimization: Optimization based on the first of the two conditions