[vectorcombine] Pull sext/zext through reduce.or/and/xor #99548
Conversation
This extends the existing foldTruncFromReductions transform to handle sext and zext as well. This is only legal for the bitwise reductions (and/or/xor) and not the arithmetic ones (add, mul). Use the same costing decision to drive whether we do the transform.
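For illustration, a sketch of the fold this enables (not taken verbatim from the patch's tests; value names are made up): the reduction is performed on the narrow source vector and only the scalar result is extended.

; before: every lane is widened, then the reduction runs at i32
%wide = zext <8 x i8> %a0 to <8 x i32>
%red = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %wide)

; after: reduce the narrow <8 x i8> source, then extend the scalar result.
; This is sound for or/and/xor because each result bit depends only on the
; matching bit of the inputs; add/mul can carry into the extended bits, so
; those reductions keep the trunc-only handling.
%narrow = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %a0)
%red2 = zext i8 %narrow to i32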
@llvm/pr-subscribers-llvm-transforms
Author: Philip Reames (preames)
Changes
This extends the existing foldTruncFromReductions transform to handle sext and zext as well. This is only legal for the bitwise reductions (and/or/xor) and not the arithmetic ones (add, mul). Use the same costing decision to drive whether we do the transform.
Full diff: https://github.com/llvm/llvm-project/pull/99548.diff
2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 3a49f95d3f117..de60d80aeffa1 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -117,7 +117,7 @@ class VectorCombine {
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
- bool foldTruncFromReductions(Instruction &I);
+ bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
/// Determine if its more efficient to fold:
/// reduce(trunc(x)) -> trunc(reduce(x)).
-bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+/// reduce(sext(x)) -> sext(reduce(x)).
+/// reduce(zext(x)) -> zext(reduce(x)).
+bool VectorCombine::foldCastFromReductions(Instruction &I) {
auto *II = dyn_cast<IntrinsicInst>(&I);
if (!II)
return false;
+ bool TruncOnly = false;
Intrinsic::ID IID = II->getIntrinsicID();
switch (IID) {
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_mul:
+ TruncOnly = true;
+ break;
case Intrinsic::vector_reduce_and:
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_xor:
@@ -2133,25 +2138,32 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
Value *ReductionSrc = I.getOperand(0);
- Value *TruncSrc;
- if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+ Value *Src;
+ if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
+ (TruncOnly ||
+ !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
return false;
- auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+ // Note: Only trunc has a constexpr; neither sext nor zext does.
+ auto CastOpc = Instruction::Trunc;
+ if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
+ CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+
+ auto *SrcTy = cast<VectorType>(Src->getType());
auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
Type *ResultTy = I.getType();
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost OldCost = TTI.getArithmeticReductionCost(
ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
- if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
+ if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
OldCost +=
- TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
- TTI::CastContextHint::None, CostKind, Trunc);
+ TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+ TTI::CastContextHint::None, CostKind, Cast);
InstructionCost NewCost =
- TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+ TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
CostKind) +
- TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+ TTI.getCastInstrCost(CastOpc, ResultTy,
ReductionSrcTy->getScalarType(),
TTI::CastContextHint::None, CostKind);
@@ -2159,9 +2171,9 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
return false;
Value *NewReduction = Builder.CreateIntrinsic(
- TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
- Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
- replaceValue(I, *NewTruncation);
+ SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+ Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
+ replaceValue(I, *NewCast);
return true;
}
@@ -2559,7 +2571,7 @@ bool VectorCombine::run() {
switch (Opcode) {
case Instruction::Call:
MadeChange |= foldShuffleFromReductions(I);
- MadeChange |= foldTruncFromReductions(I);
+ MadeChange |= foldCastFromReductions(I);
break;
case Instruction::ICmp:
case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
index 9b1aa19f85c21..f04bcc90e5c35 100644
--- a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0) {
define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = sext i8 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = sext i16 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = zext i8 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,8 +107,8 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0) {
; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
-; CHECK-NEXT: [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = zext i16 [[TMP1]] to i32
; CHECK-NEXT: ret i32 [[RED]]
;
%tr = zext <8 x i16> %a0 to <8 x i32>
@@ -116,6 +116,20 @@ define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0) {
ret i32 %red
}
+; Negative case - narrowing the reduce (to i8) is illegal.
+; TODO: We could narrow to i16 instead.
+define i32 @reduce_add_trunc_v8i8_to_v8i32(<8 x i8> %a0) {
+; CHECK-LABEL: @reduce_add_trunc_v8i8_to_v8i32(
+; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT: ret i32 [[RED]]
+;
+ %tr = zext <8 x i8> %a0 to <8 x i32>
+ %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+ ret i32 %red
+}
+
+
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)
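On the TODO in the negative test above: a hypothetical i16 narrowing (not part of this patch) would be sound, since eight zero-extended i8 lanes sum to at most 8 * 255 = 2040, which fits in i16. A sketch of what that narrowed form could look like:

%tr = zext <8 x i8> %a0 to <8 x i16>
%narrow = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %tr)
%red = zext i16 %narrow to i32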
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM
Tests for and/xor?
I left them out intentionally since we generally avoid repetitive tests. Happy to add if you'd prefer.