Skip to content

[InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. #118932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 18, 2024

Conversation

tianleliu
Copy link
Contributor

When Sel(Cmp) are in different integer type,

From: (K and N mean width, K < N; a and b are src operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN

Though Sel's operands width becomes larger, the benefit
of making type width in Sel the same as Cmp, is for combing
to max/min intrinsics, and also better performance for SIMD instructions.
References of correctness: https://alive2.llvm.org/ce/z/Y4Kegm
https://alive2.llvm.org/ce/z/qFtjtR
Reference of generated code comparision:
https://gcc.godbolt.org/z/o97svGvYM
https://gcc.godbolt.org/z/59Ynj91ov

…h of Sel to generate Max/Min intrincs.

From: (K and N mean width, K < N; a and b are src operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN

Though Sel's operands width becomes larger, the benefit
of making type width in Sel the same as Cmp, is for combing
to max/min intrinsics, and also better performance for SIMD instructions.
References of correctness: https://alive2.llvm.org/ce/z/Y4Kegm
                           https://alive2.llvm.org/ce/z/qFtjtR
Reference of generated code comparision:
                           https://gcc.godbolt.org/z/o97svGvYM
                           https://gcc.godbolt.org/z/59Ynj91ov
@tianleliu tianleliu requested a review from nikic as a code owner December 6, 2024 06:49
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Dec 6, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: None (tianleliu)

Changes

When Sel(Cmp) are in different integer type,

From: (K and N mean width, K < N; a and b are src operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN

Though Sel's operands width becomes larger, the benefit
of making type width in Sel the same as Cmp, is for combing
to max/min intrinsics, and also better performance for SIMD instructions.
References of correctness: https://alive2.llvm.org/ce/z/Y4Kegm
https://alive2.llvm.org/ce/z/qFtjtR
Reference of generated code comparision:
https://gcc.godbolt.org/z/o97svGvYM
https://gcc.godbolt.org/z/59Ynj91ov


Full diff: https://github.com/llvm/llvm-project/pull/118932.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+59-22)
  • (modified) llvm/test/Transforms/InstCombine/minmax-fold.ll (+56)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c48068afc04816..a7f621f2f4bb30 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8781,34 +8781,14 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
   return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
 }
 
-/// Helps to match a select pattern in case of a type mismatch.
-///
-/// The function processes the case when type of true and false values of a
-/// select instruction differs from type of the cmp instruction operands because
-/// of a cast instruction. The function checks if it is legal to move the cast
-/// operation after "select". If yes, it returns the new second value of
-/// "select" (with the assumption that cast is moved):
-/// 1. As operand of cast instruction when both values of "select" are same cast
-/// instructions.
-/// 2. As restored constant (by applying reverse cast operation) when the first
-/// value of the "select" is a cast operation and the second value is a
-/// constant.
-/// NOTE: We return only the new second value because the first value could be
-/// accessed as operand of cast instruction.
-static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
-                              Instruction::CastOps *CastOp) {
+static Value *lookThroughCastConst(CmpInst *CmpI, Value *V1, Value *V2,
+                                   Instruction::CastOps *CastOp) {
   auto *Cast1 = dyn_cast<CastInst>(V1);
   if (!Cast1)
     return nullptr;
 
   *CastOp = Cast1->getOpcode();
   Type *SrcTy = Cast1->getSrcTy();
-  if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
-    // If V1 and V2 are both the same cast from the same type, look through V1.
-    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
-      return Cast2->getOperand(0);
-    return nullptr;
-  }
 
   auto *C = dyn_cast<Constant>(V2);
   if (!C)
@@ -8890,6 +8870,63 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
   return CastedTo;
 }
 
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function processes the case when type of true and false values of a
+/// select instruction differs from type of the cmp instruction operands because
+/// of a cast instruction. The function checks if it is legal to move the cast
+/// operation after "select". If yes, it returns the new second value of
+/// "select" (with the assumption that cast is moved):
+/// 1. As operand of cast instruction when both values of "select" are same cast
+/// instructions.
+/// 2. As restored constant (by applying reverse cast operation) when the first
+/// value of the "select" is a cast operation and the second value is a
+/// constant. It is implemented in lookThroughCastConst().
+/// 3. As one operand is cast instruction and the other is not. The operands in
+/// sel(cmp) are in different type integer.
+/// NOTE: We return only the new second value because the first value could be
+/// accessed as operand of cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+                              Instruction::CastOps *CastOp) {
+  auto *Cast1 = dyn_cast<CastInst>(V1);
+  if (!Cast1)
+    return nullptr;
+
+  *CastOp = Cast1->getOpcode();
+  Type *SrcTy = Cast1->getSrcTy();
+  if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+    // If V1 and V2 are both the same cast from the same type, look through V1.
+    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+      return Cast2->getOperand(0);
+    return nullptr;
+  }
+
+  auto *C = dyn_cast<Constant>(V2);
+  if (C)
+    return lookThroughCastConst(CmpI, V1, V2, CastOp);
+
+  Value *CastedTo = nullptr;
+  if (*CastOp == Instruction::Trunc) {
+    Value *ExtV;
+    if (match(CmpI->getOperand(1), m_SExt(m_Value(ExtV))) &&
+        ExtV->getType() == Cast1->getType() && ExtV == V2) {
+      // Here we have the following case:
+      //   %y_ext = sext iK %y to iN
+      //   %cond = cmp iN %x, %y_ext
+      //   %tr = trunc iN %x to iK
+      //   %narrowsel = select i1 %cond, iK %tr, iK %y
+      //
+      // We can always move trunc after select operation:
+      //   %y_ext = sext iK %y to iN
+      //   %cond = cmp iN %x, %y_ext
+      //   %widesel = select i1 %cond, iN %x, iN%y_ext
+      //   %tr = trunc iN %widesel to iK
+      CastedTo = CmpI->getOperand(1);
+    }
+  }
+
+  return CastedTo;
+}
 SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
                                              Instruction::CastOps *CastOp,
                                              unsigned Depth) {
diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll
index ccdf4400b16b54..8b9610569ad508 100644
--- a/llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -697,6 +697,34 @@ define zeroext i8 @look_through_cast2(i32 %x) {
   ret i8 %res
 }
 
+define zeroext i8 @look_through_cast_int_min(i8 %a, i32 %min) {
+; CHECK-LABEL: @look_through_cast_int_min(
+; CHECK-NEXT:    [[A32:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[SEL1:%.*]] = call i32 @llvm.smin.i32(i32 [[MIN:%.*]], i32 [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc i32 [[SEL1]] to i8
+; CHECK-NEXT:    ret i8 [[SEL]]
+;
+  %a32 = sext i8 %a to i32
+  %cmp = icmp slt i32 %a32, %min
+  %min8 = trunc i32 %min to i8
+  %sel = select i1 %cmp, i8 %a, i8 %min8
+  ret i8 %sel
+}
+
+define zeroext i16 @look_through_cast_int_max(i16 %a, i32 %max) {
+; CHECK-LABEL: @look_through_cast_int_max(
+; CHECK-NEXT:    [[A32:%.*]] = sext i16 [[A:%.*]] to i32
+; CHECK-NEXT:    [[SEL1:%.*]] = call i32 @llvm.smax.i32(i32 [[MAX:%.*]], i32 [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc i32 [[SEL1]] to i16
+; CHECK-NEXT:    ret i16 [[SEL]]
+;
+  %a32 = sext i16 %a to i32
+  %cmp = icmp sgt i32 %max, %a32
+  %max8 = trunc i32 %max to i16
+  %sel = select i1 %cmp, i16 %max8, i16 %a
+  ret i16 %sel
+}
+
 define <2 x i8> @min_through_cast_vec1(<2 x i32> %x) {
 ; CHECK-LABEL: @min_through_cast_vec1(
 ; CHECK-NEXT:    [[RES1:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 510, i32 511>)
@@ -721,6 +749,34 @@ define <2 x i8> @min_through_cast_vec2(<2 x i32> %x) {
   ret <2 x i8> %res
 }
 
+define <8 x i8> @look_through_cast_int_min_vec(<8 x i8> %a, <8 x i32> %min) {
+; CHECK-LABEL: @look_through_cast_int_min_vec(
+; CHECK-NEXT:    [[A32:%.*]] = sext <8 x i8> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[SEL1:%.*]] = call <8 x i32> @llvm.smin.v8i32(<8 x i32> [[MIN:%.*]], <8 x i32> [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc <8 x i32> [[SEL1]] to <8 x i8>
+; CHECK-NEXT:    ret <8 x i8> [[SEL]]
+;
+  %a32 = sext <8 x i8> %a to <8 x i32>
+  %cmp = icmp slt <8 x i32> %a32, %min
+  %min8 = trunc <8 x i32> %min to <8 x i8>
+  %sel = select <8 x i1> %cmp, <8 x i8> %a, <8 x i8> %min8
+  ret <8 x i8> %sel
+}
+
+define <8 x i32> @look_through_cast_int_max_vec(<8 x i32> %a, <8 x i64> %max) {
+; CHECK-LABEL: @look_through_cast_int_max_vec(
+; CHECK-NEXT:    [[A32:%.*]] = sext <8 x i32> [[A:%.*]] to <8 x i64>
+; CHECK-NEXT:    [[SEL1:%.*]] = call <8 x i64> @llvm.smax.v8i64(<8 x i64> [[MAX:%.*]], <8 x i64> [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc <8 x i64> [[SEL1]] to <8 x i32>
+; CHECK-NEXT:    ret <8 x i32> [[SEL]]
+;
+  %a32 = sext <8 x i32> %a to <8 x i64>
+  %cmp = icmp sgt <8 x i64> %a32, %max
+  %max8 = trunc <8 x i64> %max to <8 x i32>
+  %sel = select <8 x i1> %cmp, <8 x i32> %a, <8 x i32> %max8
+  ret <8 x i32> %sel
+}
+
 ; Remove a min/max op in a sequence with a common operand.
 ; PR35717: https://bugs.llvm.org/show_bug.cgi?id=35717
 

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some tests with zext and unsigned predicates.

Comment on lines 8911 to 8913
if ((match(CmpI->getOperand(1), m_ZExt(m_Value(ExtV))) ||
match(CmpI->getOperand(1), m_SExt(m_Value(ExtV)))) &&
ExtV->getType() == Cast1->getType() && ExtV == V2) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ((match(CmpI->getOperand(1), m_ZExt(m_Value(ExtV))) ||
match(CmpI->getOperand(1), m_SExt(m_Value(ExtV)))) &&
ExtV->getType() == Cast1->getType() && ExtV == V2) {
if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2))) &&
V2->getType() == Cast1->getType()) {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like V2 and Cast1/V1 always have the same type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dtcxzyw Thanks for your review!
Do you mean "V2->getType() == Cast1->getType()" or "ExtV == V2" is redundant?
But the follow example is a counter example that %aa and %a are different.
define zeroext i8 @src(i8 %a, i16 %aa, i32 %min) {
%a32 = sext i16 %aa to i32
%cmp = icmp slt i32 %a32, %min
%min8 = trunc i32 %min to i8
%sel = select i1 %cmp, i8 %a, i8 %min8
ret i8 %sel
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtV == V2 is checked by m_Specific.
Since V1 and V2 are both arms of the select, they have the same type.


auto *C = dyn_cast<Constant>(V2);
if (C)
return lookThroughCastConst(CmpI, V1, V2, CastOp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please adjust lookThroughCastConst to avoid redundant checks.

static Value *lookThroughCastConst(CmpInst *CmpI, CastInst *Cast1, Constant *C,
                                    Instruction::CastOps CastOp)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean remove

  • if (!Cast1)
  • return nullptr;
    and
  • if (!C)
  • return nullptr;

I remained it just only for robustness of lookThroughCastConst, if some one will call lookThroughCastConst independently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nikic @dtcxzyw Any other comments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tianleliu The checks are not necessary if you pass through the correct types (the dyn_cast results).

Value *CastedTo = nullptr;
if (*CastOp == Instruction::Trunc) {
if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2))) &&
V2->getType() == Cast1->getType()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this check is unnecessary. Can you replace it with an assertion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @dtcxzyw , I don't understand which check you think is unnecessary? Could you please explain me more?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always pass two arms of a select into this function. If I am not missing something, the type of V1 and V2 must be equal.

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you!

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// We can always move trunc after select operation:
// %y_ext = sext iK %y to iN
// %cond = cmp iN %x, %y_ext
// %widesel = select i1 %cond, iN %x, iN%y_ext
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// %widesel = select i1 %cond, iN %x, iN%y_ext
// %widesel = select i1 %cond, iN %x, iN %y_ext

@tianleliu tianleliu merged commit d7fe2cf into llvm:main Dec 18, 2024
8 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Dec 18, 2024

LLVM Buildbot has detected a new failure on builder llvm-x86_64-debian-dylib running on gribozavr4 while building llvm at step 2 "checkout".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/60/builds/15467

Here is the relevant piece of the build log for the reference
Step 2 (checkout) failure: update (failure)
git version 2.30.2
remote: GH100: Service unavailable.
fatal: unable to access 'https://github.com/llvm/llvm-project.git/': The requested URL returned error: 503
remote: GH100: Service unavailable.
fatal: unable to access 'https://github.com/llvm/llvm-project.git/': The requested URL returned error: 503

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants