
[vectorcombine] Pull sext/zext through reduce.or/and/xor #99548

Merged — 3 commits into llvm:main, Jul 18, 2024

Conversation

@preames (Collaborator) commented Jul 18, 2024

This extends the existing foldTruncFromReductions transform to handle sext and zext as well. This is only legal for the bitwise reductions (and/or/xor) and not the arithmetic ones (add, mul). Use the same costing decision to drive whether we do the transform.
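The legality claim can be sanity-checked numerically. The following is a quick sketch (not part of the PR) that models narrow and wide integer lanes with plain Python integers: extending before or after a bitwise `or` reduction gives the same result, while the same rewrite is unsound for `add`.

```python
import itertools
import operator
from functools import reduce

BITS_IN, BITS_OUT = 4, 8  # narrow lanes widened to a larger type

def sext(v):
    # Sign-extend a BITS_IN-bit value into BITS_OUT bits.
    if v & (1 << (BITS_IN - 1)):
        v |= ((1 << BITS_OUT) - 1) & ~((1 << BITS_IN) - 1)
    return v

mask_out = (1 << BITS_OUT) - 1
mask_in = (1 << BITS_IN) - 1

ok_or = ok_add = True
for lanes in itertools.product(range(1 << BITS_IN), repeat=2):
    # reduce.or: extending first or last gives the same result.
    wide = reduce(operator.or_, map(sext, lanes))
    narrow = sext(reduce(operator.or_, lanes))
    ok_or &= wide == narrow
    # reduce.add: the same rewrite is NOT sound (carries cross lanes).
    wide_a = reduce(operator.add, map(sext, lanes)) & mask_out
    narrow_a = sext(reduce(operator.add, lanes) & mask_in)
    ok_add &= wide_a == narrow_a

print(ok_or)   # True: sext commutes with or (the same holds for and/xor)
print(ok_add)  # False: add does not commute with sext
```

The first counterexample for `add` is two lanes whose narrow sum wraps, e.g. two i4 values of 8: extend-then-reduce yields 0xF0, while reduce-then-extend wraps to 0.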

@llvmbot (Member) commented Jul 18, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Philip Reames (preames)

Changes

This extends the existing foldTruncFromReductions transform to handle sext and zext as well. This is only legal for the bitwise reductions (and/or/xor) and not the arithmetic ones (add, mul). Use the same costing decision to drive whether we do the transform.


Full diff: https://github.com/llvm/llvm-project/pull/99548.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+26-14)
  • (modified) llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll (+22-8)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 3a49f95d3f117..de60d80aeffa1 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -117,7 +117,7 @@ class VectorCombine {
   bool foldShuffleOfShuffles(Instruction &I);
   bool foldShuffleToIdentity(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
-  bool foldTruncFromReductions(Instruction &I);
+  bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
 
 /// Determine if its more efficient to fold:
 ///   reduce(trunc(x)) -> trunc(reduce(x)).
-bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+///   reduce(sext(x))  -> sext(reduce(x)).
+///   reduce(zext(x))  -> zext(reduce(x)).
+bool VectorCombine::foldCastFromReductions(Instruction &I) {
   auto *II = dyn_cast<IntrinsicInst>(&I);
   if (!II)
     return false;
 
+  bool TruncOnly = false;
   Intrinsic::ID IID = II->getIntrinsicID();
   switch (IID) {
   case Intrinsic::vector_reduce_add:
   case Intrinsic::vector_reduce_mul:
+    TruncOnly = true;
+    break;
   case Intrinsic::vector_reduce_and:
   case Intrinsic::vector_reduce_or:
   case Intrinsic::vector_reduce_xor:
@@ -2133,25 +2138,32 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
   Value *ReductionSrc = I.getOperand(0);
 
-  Value *TruncSrc;
-  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+  Value *Src;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
+      (TruncOnly ||
+       !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
     return false;
 
-  auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+  // Note: Only trunc has a constexpr, neither sext or zext do.
+  auto CastOpc = Instruction::Trunc;
+  if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
+      CastOpc = (Instruction::CastOps)cast<Instruction>(Cast)->getOpcode();
+
+  auto *SrcTy = cast<VectorType>(Src->getType());
   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
   Type *ResultTy = I.getType();
 
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   InstructionCost OldCost = TTI.getArithmeticReductionCost(
       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
-  if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
+  if (auto *Cast = dyn_cast<CastInst>(ReductionSrc))
     OldCost +=
-        TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
-                             TTI::CastContextHint::None, CostKind, Trunc);
+        TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
+                             TTI::CastContextHint::None, CostKind, Cast);
   InstructionCost NewCost =
-      TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+      TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
                                      CostKind) +
-      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+      TTI.getCastInstrCost(CastOpc, ResultTy,
                            ReductionSrcTy->getScalarType(),
                            TTI::CastContextHint::None, CostKind);
 
@@ -2159,9 +2171,9 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
     return false;
 
   Value *NewReduction = Builder.CreateIntrinsic(
-      TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
-  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
-  replaceValue(I, *NewTruncation);
+      SrcTy->getScalarType(), II->getIntrinsicID(), {Src});
+  Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
+  replaceValue(I, *NewCast);
   return true;
 }
 
@@ -2559,7 +2571,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
-        MadeChange |= foldTruncFromReductions(I);
+        MadeChange |= foldCastFromReductions(I);
         break;
       case Instruction::ICmp:
       case Instruction::FCmp:
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
index 9b1aa19f85c21..f04bcc90e5c35 100644
--- a/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0)  {
 
 define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = sext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0)  {
 
 define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i8 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,8 +107,8 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0)  {
 
 define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
 ; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
-; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = zext i16 [[TMP1]] to i32
 ; CHECK-NEXT:    ret i32 [[RED]]
 ;
   %tr = zext <8 x i16> %a0 to <8 x i32>
@@ -116,6 +116,20 @@ define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0)  {
   ret i32 %red
 }
 
+; Negative case - narrowing the reduce (to i8) is illegal.
+; TODO: We could narrow to i16 instead.
+define i32 @reduce_add_trunc_v8i8_to_v8i32(<8 x i8> %a0)  {
+; CHECK-LABEL: @reduce_add_trunc_v8i8_to_v8i32(
+; CHECK-NEXT:    [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %tr = zext <8 x i8> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+
+
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
 declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)
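The negative test above is worth double-checking: for `add`, only `trunc` commutes with the reduction (arithmetic modulo 2^n), while `zext` does not. A small sketch, again not part of the PR, brute-forces both claims over narrow lanes:

```python
import itertools

NARROW, WIDE = 4, 8
nmask, wmask = (1 << NARROW) - 1, (1 << WIDE) - 1

# reduce.add(zext(x)) vs zext(reduce.add(x)): these differ once lanes carry.
bad = None
for lanes in itertools.product(range(1 << NARROW), repeat=2):
    widened = sum(lanes) & wmask    # zext is the identity on unsigned values
    narrowed = sum(lanes) & nmask   # reduce in i4, then zero-extend
    if widened != narrowed:
        bad = lanes                 # first pair where the rewrite miscompiles
        break

# trunc, by contrast, commutes with add modulo 2^NARROW, which is why the
# existing fold handles reduce.add(trunc(x)).
trunc_ok = all(
    (sum(l & nmask for l in lanes) & nmask) == (sum(lanes) & nmask)
    for lanes in itertools.product(range(1 << WIDE), repeat=2)
)

print(bad is not None)  # True: zext does not commute with reduce.add
print(trunc_ok)         # True: trunc does commute with reduce.add
```

This matches the `TruncOnly` flag in the patch: `add`/`mul` fall through only the trunc path, while `and`/`or`/`xor` additionally accept `sext`/`zext`.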

@github-actions (bot) commented Jul 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@nikic (Contributor) left a comment:

LGTM

A reviewer (Contributor) commented:

Tests for and/xor?

@preames (Collaborator, Author) replied:

I left them out intentionally since we generally avoid repetitive tests. Happy to add if you'd prefer.

@preames preames merged commit ded35c0 into llvm:main Jul 18, 2024
4 of 6 checks passed
@preames preames deleted the pr-vector-combine-reduce-of-zext-or-sext branch July 18, 2024 20:56
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
This extends the existing foldTruncFromReductions transform to handle
sext and zext as well. This is only legal for the bitwise reductions
(and/or/xor) and not the arithmetic ones (add, mul). Use the same
costing decision to drive whether we do the transform.

Differential Revision: https://phabricator.intern.facebook.com/D60250996