[InstCombine] Fold (x < y) ? -1 : zext(x > y) and (x > y) ? 1 : sext(x < y) to ucmp/scmp(x, y) #105272

Merged: Poseydon42 merged 5 commits into llvm:main on Aug 23, 2024

Conversation

Poseydon42
Contributor

This patch expands the already existing functionality to include these two additional folds, which are nearly identical to the ones already implemented.

Proofs: https://alive2.llvm.org/ce/z/Xy7s4j
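
The two new patterns can also be sanity-checked by brute force. The C++ sketch below is not part of the patch. It models `zext`/`sext` of an `i1` with ordinary integer conversions, restricts itself to 8-bit operands, and uses hypothetical helper names (`threeWay`, `selectLtZextGt`, `selectGtSextLt`, `patternsMatchForAll8BitPairs`).

```cpp
#include <cstdint>
#include <limits>

// Reference three-way comparison returning -1, 0 or 1.
// This is the result the fold is expected to produce via scmp/ucmp.
template <typename T>
constexpr std::int8_t threeWay(T x, T y) {
  return static_cast<std::int8_t>(x < y ? -1 : (x > y ? 1 : 0));
}

// Models (x < y) ? -1 : zext(x > y); the bool-to-int conversion yields 0 or 1.
template <typename T>
constexpr std::int8_t selectLtZextGt(T x, T y) {
  return static_cast<std::int8_t>(x < y ? -1 : static_cast<int>(x > y));
}

// Models (x > y) ? 1 : sext(x < y); sign-extending a true i1 yields -1.
template <typename T>
constexpr std::int8_t selectGtSextLt(T x, T y) {
  return static_cast<std::int8_t>(x > y ? 1 : (x < y ? -1 : 0));
}

// Compares both patterns against the reference for every pair of values of
// the 8-bit type T (signed or unsigned).
template <typename T>
bool patternsMatchForAll8BitPairs() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int a = lo; a <= hi; ++a) {
    for (int b = lo; b <= hi; ++b) {
      const T x = static_cast<T>(a);
      const T y = static_cast<T>(b);
      if (selectLtZextGt(x, y) != threeWay(x, y) ||
          selectGtSextLt(x, y) != threeWay(x, y))
        return false;
    }
  }
  return true;
}
```

Instantiating `patternsMatchForAll8BitPairs` with `std::int8_t` covers the signed (scmp) case and with `std::uint8_t` the unsigned (ucmp) case. This is only an illustration; the Alive2 proofs linked above remain the authoritative check.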


@llvmbot
Member

llvmbot commented Aug 20, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Volodymyr Vasylkun (Poseydon42)

Changes

This patch expands the already existing functionality to include these two additional folds, which are nearly identical to the ones already implemented.

Proofs: https://alive2.llvm.org/ce/z/Xy7s4j


Full diff: https://github.com/llvm/llvm-project/pull/105272.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+13-4)
  • (modified) llvm/test/Transforms/InstCombine/scmp.ll (+28)
  • (modified) llvm/test/Transforms/InstCombine/select-select.ll (+11-31)
  • (modified) llvm/test/Transforms/InstCombine/ucmp.ll (+30-2)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 1f6d5759883fd0..18ffc209f259e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3560,7 +3560,9 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
 
 // This function tries to fold the following operations:
 //   (x < y) ? -1 : zext(x != y)
+//   (x < y) ? -1 : zext(x > y)
 //   (x > y) ? 1 : sext(x != y)
+//   (x > y) ? 1 : sext(x < y)
 // Into ucmp/scmp(x, y), where signedness is determined by the signedness
 // of the comparison in the original sequence.
 Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3589,16 +3591,23 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
       ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;
 
   bool Replace = false;
+  ICmpInst::Predicate ExtendedCmpPredicate;
   // (x < y) ? -1 : zext(x != y)
+  // (x < y) ? -1 : zext(x > y)
   if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
-      match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
-                                        m_Specific(RHS)))))
+      match(FV, m_ZExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+                                m_Specific(RHS)))) &&
+      (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+       ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
     Replace = true;
 
   // (x > y) ? 1 : sext(x != y)
+  // (x > y) ? 1 : sext(x < y)
   if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
-      match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
-                                        m_Specific(RHS)))))
+      match(FV, m_SExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+                                m_Specific(RHS)))) &&
+      (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+       ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
     Replace = true;
 
   if (Replace)
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 7f374c5f9a1d64..e2312140c8c13d 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -223,6 +223,20 @@ define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
   ret i8 %r
 }
 
+; Fold (x s< y) ? -1 : zext(x s> y) into scmp(x, y)
+define i8 @scmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %gt_bool = icmp sgt i32 %x, %y
+  %gt = zext i1 %gt_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %gt
+  ret i8 %r
+}
+
 ; Vector version
 define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
@@ -264,3 +278,17 @@ define i8 @scmp_from_select_ge(i32 %x, i32 %y) {
   %r = select i1 %ge, i8 %ne, i8 -1
   ret i8 %r
 }
+
+; Fold (x s> y) ? 1 : sext(x s< y)
+define i8 @scmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lt_bool = icmp slt i32 %x, %y
+  %lt = sext i1 %lt_bool to i8
+  %gt = icmp sgt i32 %x, %y
+  %r = select i1 %gt, i8 1, i8 %lt
+  ret i8 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/select-select.ll b/llvm/test/Transforms/InstCombine/select-select.ll
index 5460ba1bc55838..1feae5ab504dcf 100644
--- a/llvm/test/Transforms/InstCombine/select-select.ll
+++ b/llvm/test/Transforms/InstCombine/select-select.ll
@@ -18,9 +18,9 @@ define float @foo1(float %a) {
 
 define float @foo2(float %a) {
 ; CHECK-LABEL: @foo2(
-; CHECK-NEXT:    [[B:%.*]] = fcmp ule float [[C:%.*]], 0.000000e+00
-; CHECK-NEXT:    [[D:%.*]] = fcmp olt float [[C]], 1.000000e+00
-; CHECK-NEXT:    [[E:%.*]] = select i1 [[D]], float [[C]], float 1.000000e+00
+; CHECK-NEXT:    [[B:%.*]] = fcmp ule float [[A:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = fcmp olt float [[A]], 1.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = select i1 [[TMP1]], float [[A]], float 1.000000e+00
 ; CHECK-NEXT:    [[F:%.*]] = select i1 [[B]], float 0.000000e+00, float [[E]]
 ; CHECK-NEXT:    ret float [[F]]
 ;
@@ -330,10 +330,7 @@ define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_slt_sgt(
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp slt i32 %a, %b
@@ -345,10 +342,7 @@ define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_ult_ugt(
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp ult i32 %a, %b
@@ -360,10 +354,7 @@ define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_sgt_slt(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp sgt i32 %a, %b
@@ -375,10 +366,7 @@ define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_ugt_ult(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt i32 %a, %b
@@ -460,8 +448,7 @@ define i8 @strong_order_cmp_ugt_ult_zext_not_oneuse(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
 ; CHECK-NEXT:    call void @use8(i8 [[ZEXT]])
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt i32 %a, %b
@@ -477,8 +464,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
 ; CHECK-NEXT:    call void @use8(i8 [[SEXT]])
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp slt i32 %a, %b
@@ -491,10 +477,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
 
 define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 -1, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
 ; CHECK-NEXT:    ret <2 x i8> [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt <2 x i32> %a, %b
@@ -506,10 +489,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
 
 define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector_poison(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 poison, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
 ; CHECK-NEXT:    ret <2 x i8> [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt <2 x i32> %a, %b
diff --git a/llvm/test/Transforms/InstCombine/ucmp.ll b/llvm/test/Transforms/InstCombine/ucmp.ll
index ad8a57825253b0..13755f13bb0a11 100644
--- a/llvm/test/Transforms/InstCombine/ucmp.ll
+++ b/llvm/test/Transforms/InstCombine/ucmp.ll
@@ -222,6 +222,20 @@ define i8 @ucmp_from_select_lt(i32 %x, i32 %y) {
   ret i8 %r
 }
 
+; Fold (x u< y) ? -1 : zext(x u> y) into ucmp(x, y)
+define i8 @ucmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %gt_bool = icmp ugt i32 %x, %y
+  %gt = zext i1 %gt_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %gt
+  ret i8 %r
+}
+
 ; Vector version
 define <4 x i8> @ucmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: define <4 x i8> @ucmp_from_select_vec_lt(
@@ -349,13 +363,13 @@ define i8 @ucmp_from_select_le_neg1(i32 %x, i32 %y) {
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
 ; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ult i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[NE:%.*]] = sext i1 [[NE_BOOL]] to i8
-; CHECK-NEXT:    [[LE_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[LE_NOT:%.*]] = icmp ult i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[R:%.*]] = select i1 [[LE_NOT]], i8 1, i8 [[NE]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %ne_bool = icmp ult i32 %x, %y
   %ne = sext i1 %ne_bool to i8
-  %le = icmp ule i32 %x, %y
+  %le = icmp uge i32 %x, %y
   %r = select i1 %le, i8 %ne, i8 1
   ret i8 %r
 }
@@ -513,3 +527,17 @@ define i8 @ucmp_from_select_ge_neg4(i32 %x, i32 %y) {
   %r = select i1 %ge, i8 %ne, i8 3
   ret i8 %r
 }
+
+; Fold (x > y) ? 1 : sext(x < y)
+define i8 @ucmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lt_bool = icmp ult i32 %x, %y
+  %lt = sext i1 %lt_bool to i8
+  %gt = icmp ugt i32 %x, %y
+  %r = select i1 %gt, i8 1, i8 %lt
+  ret i8 %r
+}
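
The predicate check added in `foldSelectToCmp` can be read as: the extended comparison must be either `!=` or the operand-swapped counterpart of the select's own predicate, which also forces the signedness to agree. The following toy model illustrates just that condition. It assumes a reduced, hypothetical `Pred` enum and `swapped` helper standing in for `ICmpInst::Predicate` and `ICmpInst::getSwappedPredicate`, and it leaves out the other checks in the patch (the `isLT`/`isGT` test and the constant in the true arm).

```cpp
// Toy stand-in for a handful of integer comparison predicates.
// This is NOT the LLVM enum, only a model for illustration.
enum class Pred { NE, SLT, SGT, ULT, UGT };

// Predicate obtained after exchanging the two operands of a comparison,
// in the spirit of ICmpInst::getSwappedPredicate.
constexpr Pred swapped(Pred P) {
  switch (P) {
  case Pred::SLT: return Pred::SGT;
  case Pred::SGT: return Pred::SLT;
  case Pred::ULT: return Pred::UGT;
  case Pred::UGT: return Pred::ULT;
  case Pred::NE:  return Pred::NE; // x != y is symmetric
  }
  return P;
}

// Models the condition from the patch:
//   ExtendedCmpPredicate == NE ||
//   getSwappedPredicate(ExtendedCmpPredicate) == Pred
constexpr bool extendedCmpAccepted(Pred SelectPred, Pred ExtendedPred) {
  return ExtendedPred == Pred::NE || swapped(ExtendedPred) == SelectPred;
}
```

Under this model, a select on `SLT` accepts an extended `SGT` or `NE` comparison but rejects `UGT`, so a signed outer comparison paired with an unsigned inner one is not accepted.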

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Aug 20, 2024
@nikic nikic requested a review from dtcxzyw August 20, 2024 19:23
@dtcxzyw
Member

dtcxzyw commented Aug 20, 2024

Missing fold: https://alive2.llvm.org/ce/z/9wTPep

@Poseydon42
Contributor Author

> Missing fold: https://alive2.llvm.org/ce/z/9wTPep

Addressed in #105583

@nikic
Contributor

nikic commented Aug 22, 2024

After this, I think the only remaining pattern is where the outer select has an equality condition. And I just found that we have a function for matching it already:

bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
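
As a rough illustration of what a pattern with an equality condition on the outer select might look like, the sketch below checks that `(x == y) ? 0 : ((x < y) ? -1 : 1)` is a three-way comparison. The exact forms accepted by `matchThreeWayIntCompare` are not shown in this thread, so this shape is an assumption, and the helper names are hypothetical.

```cpp
#include <cstdint>
#include <limits>

// Assumed shape: the outer select tests equality, the inner one tests less-than.
template <typename T>
constexpr std::int8_t eqOuterSelect(T x, T y) {
  return static_cast<std::int8_t>(x == y ? 0 : (x < y ? -1 : 1));
}

// Reference three-way comparison returning -1, 0 or 1.
template <typename T>
constexpr std::int8_t threeWayReference(T x, T y) {
  return static_cast<std::int8_t>((x > y) - (x < y));
}

// Exhaustive comparison over every pair of values of the 8-bit type T.
template <typename T>
bool eqOuterSelectIsThreeWay() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int a = lo; a <= hi; ++a)
    for (int b = lo; b <= hi; ++b)
      if (eqOuterSelect(static_cast<T>(a), static_cast<T>(b)) !=
          threeWayReference(static_cast<T>(a), static_cast<T>(b)))
        return false;
  return true;
}
```

The check holds for both `std::int8_t` and `std::uint8_t`, which correspond to the signed and unsigned comparison cases.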

@Poseydon42 Poseydon42 force-pushed the more-folds-to-uscmp branch from 204c9e9 to 9788782 Compare August 22, 2024 14:06
@Poseydon42
Copy link
Contributor Author

> Missing fold: https://alive2.llvm.org/ce/z/9wTPep

Should be fixed now.

Contributor

@nikic nikic left a comment

LGTM

@Poseydon42 Poseydon42 merged commit da6f423 into llvm:main Aug 23, 2024
6 of 8 checks passed