[InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. #118932
…h of Sel to generate Max/Min intrinsics.
From: (K and N mean width, K < N; a and b are src operands.)
  bN = Ext(bK)
  cond = Cmp(aN, bN)
  aK = Trunc aN
  retK = Sel(cond, aK, bK)
To:
  bN = Ext(bK)
  cond = Cmp(aN, bN)
  retN = Sel(cond, aN, bN)
  retK = Trunc retN
Although Sel's operands become wider, making the type width in Sel the same as in Cmp allows combining into max/min intrinsics, and also gives better performance for SIMD instructions.
References of correctness:
https://alive2.llvm.org/ce/z/Y4Kegm
https://alive2.llvm.org/ce/z/qFtjtR
Reference of generated code comparison:
https://gcc.godbolt.org/z/o97svGvYM
https://gcc.godbolt.org/z/59Ynj91ov
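The From/To rewrite above can be sanity-checked by brute force outside LLVM. The following is a minimal sketch, not part of the patch, assuming K = 8, N = 16, a signed-less-than compare, and Ext = sext; the function names are made up for illustration. Narrow results are returned as `uint8_t` bit patterns so that every conversion is well defined in C++.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// "From" form: the select works on the narrow (K = 8) type.
inline uint8_t narrowSelect(int16_t aN, int8_t bK) {
  int16_t bN = static_cast<int16_t>(bK); // bN   = SExt(bK)
  bool cond = aN < bN;                   // cond = Cmp(aN, bN), here slt
  // retK = Sel(cond, Trunc aN, bK), as an 8-bit pattern
  return cond ? static_cast<uint8_t>(aN) : static_cast<uint8_t>(bK);
}

// "To" form: the select works on the wide (N = 16) type, then truncates.
inline uint8_t wideSelect(int16_t aN, int8_t bK) {
  int16_t bN = static_cast<int16_t>(bK); // bN   = SExt(bK)
  int16_t retN = std::min(aN, bN);       // Sel(aN < bN, aN, bN) is a signed min
  return static_cast<uint8_t>(retN);     // retK = Trunc retN
}

// Compares both forms for every (aN, bK) pair; returns the mismatch count.
inline long countMismatches() {
  long mismatches = 0;
  for (int a = INT16_MIN; a <= INT16_MAX; ++a)
    for (int b = INT8_MIN; b <= INT8_MAX; ++b)
      if (narrowSelect(static_cast<int16_t>(a), static_cast<int8_t>(b)) !=
          wideSelect(static_cast<int16_t>(a), static_cast<int8_t>(b)))
        ++mismatches;
  return mismatches;
}
```

The two forms agree because truncating the extended value gives back bK, so both arms of the wide select truncate to the arms of the narrow one. The Alive2 links remain the authoritative correctness reference; this sketch only covers one width pair and one predicate.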
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms
Author: None (tianleliu)
Changes
When Sel(Cmp) are in different integer type,
From: (K and N mean width, K < N; a and b are src operands.)
Though Sel's operands width becomes larger, the benefit
Full diff: https://github.com/llvm/llvm-project/pull/118932.diff
2 Files Affected:
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c48068afc04816..a7f621f2f4bb30 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8781,34 +8781,14 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
}
-/// Helps to match a select pattern in case of a type mismatch.
-///
-/// The function processes the case when type of true and false values of a
-/// select instruction differs from type of the cmp instruction operands because
-/// of a cast instruction. The function checks if it is legal to move the cast
-/// operation after "select". If yes, it returns the new second value of
-/// "select" (with the assumption that cast is moved):
-/// 1. As operand of cast instruction when both values of "select" are same cast
-/// instructions.
-/// 2. As restored constant (by applying reverse cast operation) when the first
-/// value of the "select" is a cast operation and the second value is a
-/// constant.
-/// NOTE: We return only the new second value because the first value could be
-/// accessed as operand of cast instruction.
-static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
- Instruction::CastOps *CastOp) {
+static Value *lookThroughCastConst(CmpInst *CmpI, Value *V1, Value *V2,
+ Instruction::CastOps *CastOp) {
auto *Cast1 = dyn_cast<CastInst>(V1);
if (!Cast1)
return nullptr;
*CastOp = Cast1->getOpcode();
Type *SrcTy = Cast1->getSrcTy();
- if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
- // If V1 and V2 are both the same cast from the same type, look through V1.
- if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
- return Cast2->getOperand(0);
- return nullptr;
- }
auto *C = dyn_cast<Constant>(V2);
if (!C)
@@ -8890,6 +8870,63 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
return CastedTo;
}
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function processes the case when type of true and false values of a
+/// select instruction differs from type of the cmp instruction operands because
+/// of a cast instruction. The function checks if it is legal to move the cast
+/// operation after "select". If yes, it returns the new second value of
+/// "select" (with the assumption that cast is moved):
+/// 1. As operand of cast instruction when both values of "select" are same cast
+/// instructions.
+/// 2. As restored constant (by applying reverse cast operation) when the first
+/// value of the "select" is a cast operation and the second value is a
+/// constant. It is implemented in lookThroughCastConst().
+/// 3. As one operand is cast instruction and the other is not. The operands in
+/// sel(cmp) are in different type integer.
+/// NOTE: We return only the new second value because the first value could be
+/// accessed as operand of cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+ Instruction::CastOps *CastOp) {
+ auto *Cast1 = dyn_cast<CastInst>(V1);
+ if (!Cast1)
+ return nullptr;
+
+ *CastOp = Cast1->getOpcode();
+ Type *SrcTy = Cast1->getSrcTy();
+ if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+ // If V1 and V2 are both the same cast from the same type, look through V1.
+ if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+ return Cast2->getOperand(0);
+ return nullptr;
+ }
+
+ auto *C = dyn_cast<Constant>(V2);
+ if (C)
+ return lookThroughCastConst(CmpI, V1, V2, CastOp);
+
+ Value *CastedTo = nullptr;
+ if (*CastOp == Instruction::Trunc) {
+ Value *ExtV;
+ if (match(CmpI->getOperand(1), m_SExt(m_Value(ExtV))) &&
+ ExtV->getType() == Cast1->getType() && ExtV == V2) {
+ // Here we have the following case:
+ // %y_ext = sext iK %y to iN
+ // %cond = cmp iN %x, %y_ext
+ // %tr = trunc iN %x to iK
+ // %narrowsel = select i1 %cond, iK %tr, iK %y
+ //
+ // We can always move trunc after select operation:
+ // %y_ext = sext iK %y to iN
+ // %cond = cmp iN %x, %y_ext
+ // %widesel = select i1 %cond, iN %x, iN%y_ext
+ // %tr = trunc iN %widesel to iK
+ CastedTo = CmpI->getOperand(1);
+ }
+ }
+
+ return CastedTo;
+}
SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp,
unsigned Depth) {
diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll
index ccdf4400b16b54..8b9610569ad508 100644
--- a/llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -697,6 +697,34 @@ define zeroext i8 @look_through_cast2(i32 %x) {
ret i8 %res
}
+define zeroext i8 @look_through_cast_int_min(i8 %a, i32 %min) {
+; CHECK-LABEL: @look_through_cast_int_min(
+; CHECK-NEXT: [[A32:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smin.i32(i32 [[MIN:%.*]], i32 [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i8
+; CHECK-NEXT: ret i8 [[SEL]]
+;
+ %a32 = sext i8 %a to i32
+ %cmp = icmp slt i32 %a32, %min
+ %min8 = trunc i32 %min to i8
+ %sel = select i1 %cmp, i8 %a, i8 %min8
+ ret i8 %sel
+}
+
+define zeroext i16 @look_through_cast_int_max(i16 %a, i32 %max) {
+; CHECK-LABEL: @look_through_cast_int_max(
+; CHECK-NEXT: [[A32:%.*]] = sext i16 [[A:%.*]] to i32
+; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smax.i32(i32 [[MAX:%.*]], i32 [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i16
+; CHECK-NEXT: ret i16 [[SEL]]
+;
+ %a32 = sext i16 %a to i32
+ %cmp = icmp sgt i32 %max, %a32
+ %max8 = trunc i32 %max to i16
+ %sel = select i1 %cmp, i16 %max8, i16 %a
+ ret i16 %sel
+}
+
define <2 x i8> @min_through_cast_vec1(<2 x i32> %x) {
; CHECK-LABEL: @min_through_cast_vec1(
; CHECK-NEXT: [[RES1:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 510, i32 511>)
@@ -721,6 +749,34 @@ define <2 x i8> @min_through_cast_vec2(<2 x i32> %x) {
ret <2 x i8> %res
}
+define <8 x i8> @look_through_cast_int_min_vec(<8 x i8> %a, <8 x i32> %min) {
+; CHECK-LABEL: @look_through_cast_int_min_vec(
+; CHECK-NEXT: [[A32:%.*]] = sext <8 x i8> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i32> @llvm.smin.v8i32(<8 x i32> [[MIN:%.*]], <8 x i32> [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i32> [[SEL1]] to <8 x i8>
+; CHECK-NEXT: ret <8 x i8> [[SEL]]
+;
+ %a32 = sext <8 x i8> %a to <8 x i32>
+ %cmp = icmp slt <8 x i32> %a32, %min
+ %min8 = trunc <8 x i32> %min to <8 x i8>
+ %sel = select <8 x i1> %cmp, <8 x i8> %a, <8 x i8> %min8
+ ret <8 x i8> %sel
+}
+
+define <8 x i32> @look_through_cast_int_max_vec(<8 x i32> %a, <8 x i64> %max) {
+; CHECK-LABEL: @look_through_cast_int_max_vec(
+; CHECK-NEXT: [[A32:%.*]] = sext <8 x i32> [[A:%.*]] to <8 x i64>
+; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i64> @llvm.smax.v8i64(<8 x i64> [[MAX:%.*]], <8 x i64> [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i64> [[SEL1]] to <8 x i32>
+; CHECK-NEXT: ret <8 x i32> [[SEL]]
+;
+ %a32 = sext <8 x i32> %a to <8 x i64>
+ %cmp = icmp sgt <8 x i64> %a32, %max
+ %max8 = trunc <8 x i64> %max to <8 x i32>
+ %sel = select <8 x i1> %cmp, <8 x i32> %a, <8 x i32> %max8
+ ret <8 x i32> %sel
+}
+
; Remove a min/max op in a sequence with a common operand.
; PR35717: https://bugs.llvm.org/show_bug.cgi?id=35717
Please add some tests with zext and unsigned predicates.
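As a rough illustration of what the zext/unsigned variant looks like, here is a standalone brute-force sketch. It is not one of the tests added to the PR; it assumes K = 8, N = 16, an unsigned-greater-than compare, and made-up function names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Narrow form: select(icmp ugt aN, zext(bK)), trunc(aN), bK).
inline uint8_t narrowSelectU(uint16_t aN, uint8_t bK) {
  uint16_t bN = bK;   // zext i8 -> i16
  bool cond = aN > bN; // unsigned compare
  return cond ? static_cast<uint8_t>(aN) : bK;
}

// Wide form: trunc(umax(aN, zext(bK))).
inline uint8_t wideSelectU(uint16_t aN, uint8_t bK) {
  uint16_t bN = bK;                 // zext i8 -> i16
  uint16_t retN = std::max(aN, bN); // select(aN > bN, aN, bN) is an unsigned max
  return static_cast<uint8_t>(retN);
}

// Compares both forms for every (aN, bK) pair; returns the mismatch count.
inline long countMismatchesU() {
  long mismatches = 0;
  for (unsigned a = 0; a <= UINT16_MAX; ++a)
    for (unsigned b = 0; b <= UINT8_MAX; ++b)
      if (narrowSelectU(static_cast<uint16_t>(a), static_cast<uint8_t>(b)) !=
          wideSelectU(static_cast<uint16_t>(a), static_cast<uint8_t>(b)))
        ++mismatches;
  return mismatches;
}
```

In the actual test file, the equivalent check would be an IR function using `zext` and `icmp ugt`/`icmp ult` with FileCheck lines for the unsigned max/min intrinsic.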
llvm/lib/Analysis/ValueTracking.cpp
Outdated
if ((match(CmpI->getOperand(1), m_ZExt(m_Value(ExtV))) ||
     match(CmpI->getOperand(1), m_SExt(m_Value(ExtV)))) &&
    ExtV->getType() == Cast1->getType() && ExtV == V2) {
- if ((match(CmpI->getOperand(1), m_ZExt(m_Value(ExtV))) ||
-      match(CmpI->getOperand(1), m_SExt(m_Value(ExtV)))) &&
-     ExtV->getType() == Cast1->getType() && ExtV == V2) {
+ if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2))) &&
+     V2->getType() == Cast1->getType()) {
Looks like V2 and Cast1/V1 always have the same type.
@dtcxzyw Thanks for your review!
Do you mean that "V2->getType() == Cast1->getType()" or "ExtV == V2" is redundant?
The following is a counterexample, in which %aa and %a are different values:
define zeroext i8 @src(i8 %a, i16 %aa, i32 %min) {
%a32 = sext i16 %aa to i32
%cmp = icmp slt i32 %a32, %min
%min8 = trunc i32 %min to i8
%sel = select i1 %cmp, i8 %a, i8 %min8
ret i8 %sel
}
ExtV == V2 is checked by m_Specific.
Since V1 and V2 are both arms of the select, they have the same type.
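To make the counterexample concrete, here is a small standalone model of the `@src` function quoted above, next to a deliberately wrong widening that ignores the "extended value must be the select arm" requirement. This is a sketch for illustration only; the function names are invented, and narrow results are returned as `uint8_t` bit patterns.

```cpp
#include <cassert>
#include <cstdint>

// Model of @src: the compare extends %aa, but the select arm is %a.
inline uint8_t originalSelect(int8_t a, int16_t aa, int32_t min) {
  int32_t a32 = static_cast<int32_t>(aa); // %a32  = sext i16 %aa to i32
  bool cmp = a32 < min;                   // %cmp  = icmp slt i32 %a32, %min
  // %min8 = trunc i32 %min to i8; %sel = select i1 %cmp, i8 %a, i8 %min8
  return cmp ? static_cast<uint8_t>(a) : static_cast<uint8_t>(min);
}

// Hypothetical (invalid) rewrite that widens the select anyway, treating
// %a32 as if it were the extension of %a.
inline uint8_t wronglyWidenedSelect(int8_t a, int16_t aa, int32_t min) {
  (void)a;                              // %a is lost: this is the bug
  int32_t a32 = static_cast<int32_t>(aa);
  int32_t wide = a32 < min ? a32 : min; // select on the wide type
  return static_cast<uint8_t>(wide);    // trunc i32 -> i8
}
```

With a = 1, aa = 2, min = 100 the original returns 1 while the wrong rewrite returns 2, which is why the matched extension operand has to be exactly V2 (the job `m_Specific(V2)` does in the suggested code).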
llvm/lib/Analysis/ValueTracking.cpp
Outdated
auto *C = dyn_cast<Constant>(V2);
if (C)
  return lookThroughCastConst(CmpI, V1, V2, CastOp);
Please adjust lookThroughCastConst to avoid redundant checks.
static Value *lookThroughCastConst(CmpInst *CmpI, CastInst *Cast1, Constant *C,
                                   Instruction::CastOps CastOp)
Do you mean removing
- if (!Cast1)
-   return nullptr;
and
- if (!C)
-   return nullptr;
I kept them only for the robustness of lookThroughCastConst, in case someone calls lookThroughCastConst independently.
@tianleliu The checks are not necessary if you pass through the correct types (the dyn_cast results).
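The point about passing the cast results through can be sketched without LLVM. The classes below are hypothetical stand-ins (they are not `llvm::Value`, `llvm::CastInst`, or `llvm::Constant`), and `dynamic_cast` stands in for `dyn_cast`; the helper body is a placeholder, not the real constant-restoring logic.

```cpp
#include <cassert>

// Hypothetical miniature class hierarchy, for illustration only.
struct Value { virtual ~Value() = default; };
struct CastNode : Value { Value *Src = nullptr; };
struct ConstNode : Value { long Bits = 0; };

// The helper takes already-typed pointers, so it cannot be handed a
// non-cast or a non-constant and needs no "if (!Cast1)" / "if (!C)" re-checks.
inline Value *lookThroughConst(CastNode *Cast1, ConstNode *C) {
  assert(Cast1 && C && "caller passes the successful cast results");
  return Cast1->Src; // placeholder result for this sketch
}

// The caller performs each dynamic cast once and forwards the results.
inline Value *lookThrough(Value *V1, Value *V2) {
  auto *Cast1 = dynamic_cast<CastNode *>(V1);
  if (!Cast1)
    return nullptr;
  if (auto *C = dynamic_cast<ConstNode *>(V2))
    return lookThroughConst(Cast1, C);
  return nullptr;
}
```

The design choice is that the type system carries the precondition: a future independent caller still has to produce a `CastNode *` and a `ConstNode *` before it can call the helper at all.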
llvm/lib/Analysis/ValueTracking.cpp
Outdated
Value *CastedTo = nullptr;
if (*CastOp == Instruction::Trunc) {
  if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2))) &&
      V2->getType() == Cast1->getType()) {
I still think this check is unnecessary. Can you replace it with an assertion?
Sorry @dtcxzyw, I don't understand which check you think is unnecessary. Could you please explain a bit more?
We always pass two arms of a select into this function. If I am not missing something, the type of V1 and V2 must be equal.
LGTM. Thank you!
LGTM
llvm/lib/Analysis/ValueTracking.cpp
Outdated
// We can always move trunc after select operation:
// %y_ext = sext iK %y to iN
// %cond = cmp iN %x, %y_ext
// %widesel = select i1 %cond, iN %x, iN%y_ext
- // %widesel = select i1 %cond, iN %x, iN%y_ext
+ // %widesel = select i1 %cond, iN %x, iN %y_ext
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/60/builds/15467 Here is the relevant piece of the build log for the reference
When a Sel and its Cmp use different integer types,
From: (K and N mean width, K < N; a and b are src operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN
Although Sel's operands become wider, making the type width
in Sel the same as in Cmp allows combining into max/min
intrinsics, and also gives better performance for SIMD instructions.
References of correctness: https://alive2.llvm.org/ce/z/Y4Kegm
https://alive2.llvm.org/ce/z/qFtjtR
Reference of generated code comparison:
https://gcc.godbolt.org/z/o97svGvYM
https://gcc.godbolt.org/z/59Ynj91ov