Skip to content

[SLP]Use getExtendedReduction cost and fix reduction cost calculations #117350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

alexey-bataev
Copy link
Member

Patch uses getExtendedReduction for reductions of ext-based nodes + adds
cost estimation for ctpop-kind reductions into basic implementation and
RISCV-V specific vcpop cost estimation.

Created using spr 1.3.5
@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2024

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Patch uses getExtendedReduction for reductions of ext-based nodes + adds
cost estimation for ctpop-kind reductions into basic implementation and
RISCV-V specific vcpop cost estimation.


Full diff: https://github.com/llvm/llvm-project/pull/117350.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+12)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+8)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+100-51)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll (+4-14)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll (+1-1)
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..d4bd504bf9ba31 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                            Type *ResTy, VectorType *Ty,
                                            FastMathFlags FMF,
                                            TTI::TargetCostKind CostKind) {
+    if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
+        FTy && Opcode == Instruction::Add &&
+        FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
+      // Represent vector_reduce_add(ZExt(<n x i1>)) as
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      auto *IntTy =
+          IntegerType::get(ResTy->getContext(), FTy->getNumElements());
+      IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+      return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
+                                       TTI::CastContextHint::None, CostKind) +
+             thisT()->getIntrinsicInstrCost(ICA, CostKind);
+    }
     // Without any native support, this is equivalent to the cost of
     // vecreduce.opcode(ext(Ty A)).
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2b16dcbcd8695b..026b1e694e6a64 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
 
+  if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
+      LT.second.getScalarType() == MVT::i1) {
+    // Represent vector_reduce_add(ZExt(<n x i1>)) as
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    return LT.first *
+           getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
+  }
+
   if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
     return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
                                            FMF, CostKind);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8e0ca2677bf0a9..46ae908f57ab89 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,22 +1371,46 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  /// Returns the type/is-signed info for the root node in the graph without
+  /// casting.
+  std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+        !Root.Scalars.front()->getType()->isIntegerTy())
+      return std::nullopt;
+    auto It = MinBWs.find(&Root);
+    if (It != MinBWs.end())
+      return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+                                             It->second.first),
+                            It->second.second);
+    if (Root.getOpcode() == Instruction::ZExt ||
+        Root.getOpcode() == Instruction::SExt)
+      return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+                            Root.getOpcode() == Instruction::SExt);
+    return std::nullopt;
+  }
+
   /// Checks if the root graph node can be emitted with narrower bitwidth at
   /// codegen and returns it signedness, if so.
   bool isSignedMinBitwidthRootNode() const {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
-  /// Returns reduction bitwidth and signedness, if it does not match the
-  /// original requested size.
-  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+  /// Returns reduction type after minbitdth analysis.
+  FixedVectorType *getReductionType() const {
     if (ReductionBitWidth == 0 ||
+        !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
         ReductionBitWidth >=
             DL->getTypeSizeInBits(
                 VectorizableTree.front()->Scalars.front()->getType()))
-      return std::nullopt;
-    return std::make_pair(ReductionBitWidth,
-                          MinBWs.at(VectorizableTree.front().get()).second);
+      return getWidenedType(
+          VectorizableTree.front()->Scalars.front()->getType(),
+          VectorizableTree.front()->getVectorFactor());
+    return getWidenedType(
+        IntegerType::get(
+            VectorizableTree.front()->Scalars.front()->getContext(),
+            ReductionBitWidth),
+        VectorizableTree.front()->getVectorFactor());
   }
 
   /// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
       TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+      bool IsArithmeticExtendedReduction =
+          E->Idx == 0 && UserIgnoreList &&
+          all_of(*UserIgnoreList, [](Value *V) {
+            auto *I = cast<Instruction>(V);
+            return is_contained({Instruction::Add, Instruction::FAdd,
+                                 Instruction::Mul, Instruction::FMul,
+                                 Instruction::And, Instruction::Or,
+                                 Instruction::Xor},
+                                I->getOpcode());
+          });
+      if (IsArithmeticExtendedReduction &&
+          (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+        return CommonCost;
       return CommonCost +
              TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
                                    VecOpcode == Opcode ? VI : nullptr);
@@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       unsigned SrcSize = It->second.first;
       unsigned DstSize = ReductionBitWidth;
       unsigned Opcode = Instruction::Trunc;
-      if (SrcSize < DstSize)
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
-      auto *SrcVecTy =
-          getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
-      auto *DstVecTy =
-          getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
-      TTI::CastContextHint CCH = getCastContextHint(E);
-      InstructionCost CastCost;
-      switch (E.getOpcode()) {
-      case Instruction::SExt:
-      case Instruction::ZExt:
-      case Instruction::Trunc: {
-        const TreeEntry *OpTE = getOperandEntry(&E, 0);
-        CCH = getCastContextHint(*OpTE);
-        break;
-      }
-      default:
-        break;
+      if (SrcSize < DstSize) {
+        bool IsArithmeticExtendedReduction =
+            all_of(*UserIgnoreList, [](Value *V) {
+              auto *I = cast<Instruction>(V);
+              return is_contained({Instruction::Add, Instruction::FAdd,
+                                   Instruction::Mul, Instruction::FMul,
+                                   Instruction::And, Instruction::Or,
+                                   Instruction::Xor},
+                                  I->getOpcode());
+            });
+        if (IsArithmeticExtendedReduction)
+          Opcode =
+              Instruction::BitCast; // Handle it by getExtendedReductionCost
+        else
+          Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+      if (Opcode != Instruction::BitCast) {
+        auto *SrcVecTy =
+            getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+        auto *DstVecTy =
+            getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+        TTI::CastContextHint CCH = getCastContextHint(E);
+        InstructionCost CastCost;
+        switch (E.getOpcode()) {
+        case Instruction::SExt:
+        case Instruction::ZExt:
+        case Instruction::Trunc: {
+          const TreeEntry *OpTE = getOperandEntry(&E, 0);
+          CCH = getCastContextHint(*OpTE);
+          break;
+        }
+        default:
+          break;
+        }
+        CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                          TTI::TCK_RecipThroughput);
+        Cost += CastCost;
+        LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                          << " for final resize for reduction from " << SrcVecTy
+                          << " to " << DstVecTy << "\n";
+                   dbgs() << "SLP: Current total cost = " << Cost << "\n");
       }
-      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
-                                        TTI::TCK_RecipThroughput);
-      Cost += CastCost;
-      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
-                        << " for final resize for reduction from " << SrcVecTy
-                        << " to " << DstVecTy << "\n";
-                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
     }
   }
 
@@ -19815,8 +19869,8 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost TreeCost = V.getTreeCost(VL);
-        InstructionCost ReductionCost = getReductionCost(
-            TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+        InstructionCost ReductionCost =
+            getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -20107,14 +20161,14 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(
-      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
-      bool IsCmpSelMinMax, FastMathFlags FMF,
-      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+  InstructionCost getReductionCost(TargetTransformInfo *TTI,
+                                   ArrayRef<Value *> ReducedVals,
+                                   bool IsCmpSelMinMax, FastMathFlags FMF,
+                                   const BoUpSLP &R) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
-    FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = R.getReductionType();
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at the compile time.
@@ -20172,21 +20226,16 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          auto [Bitwidth, IsSigned] =
-              BitwidthAndSign.value_or(std::make_pair(0u, false));
-          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
-            // Represent vector_reduce_add(ZExt(<n x i1>)) to
-            // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
-            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
-            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
-            VectorCost =
-                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
-                                      getWidenedType(ScalarTy, ReduxWidth),
-                                      TTI::CastContextHint::None, CostKind) +
-                TTI->getIntrinsicInstrCost(ICA, CostKind);
-          } else {
+          Type *RedTy = VectorTy->getElementType();
+          auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+              std::make_pair(RedTy, true));
+          if (RType == RedTy) {
             VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                          FMF, CostKind);
+          } else {
+            VectorCost = TTI->getExtendedReductionCost(
+                RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+                FMF, CostKind);
           }
         }
       }
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index bc24a44cecbe39..85131758853b3d 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -877,20 +877,10 @@ entry:
 define i64 @red_zext_ld_4xi64(ptr %ptr) {
 ; CHECK-LABEL: @red_zext_ld_4xi64(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
-; CHECK-NEXT:    [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT:    [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
-; CHECK-NEXT:    [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
-; CHECK-NEXT:    [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
-; CHECK-NEXT:    [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
-; CHECK-NEXT:    [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
-; CHECK-NEXT:    [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
-; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
-; CHECK-NEXT:    [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
-; CHECK-NEXT:    [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
-; CHECK-NEXT:    [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
+; CHECK-NEXT:    [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
 ; CHECK-NEXT:    ret i64 [[ADD_3]]
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
index e4d20a6db8fa67..09c11bbefd4a35 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
@@ -8,7 +8,7 @@
 ; YAML-NEXT: Function:        test
 ; YAML-NEXT: Args:
 ; YAML-NEXT:   - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:   - Cost:            '-1'
+; YAML-NEXT:   - Cost:            '-10'
 ; YAML-NEXT:   - String:          ' and with tree size '
 ; YAML-NEXT:   - TreeSize:        '8'
 ; YAML-NEXT:...

@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Type *ResTy, VectorType *Ty,
FastMathFlags FMF,
TTI::TargetCostKind CostKind) {
if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
FTy && Opcode == Instruction::Add &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing the IsUnsigned check here - this allows sext too.

@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(

std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);

if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
LT.second.getScalarType() == MVT::i1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, missing the unsigned check.

Created using spr 1.3.5
Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - Note that I'm only confident in the TTI bits. The SLP bits aren't obviously broken, but that's all I can say.

@alexey-bataev alexey-bataev merged commit 7523086 into main Nov 22, 2024
5 of 8 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpuse-getextendedreduction-cost-and-fix-reduction-cost-calculations branch November 22, 2024 21:12
nasherm added a commit to nasherm/llvm-project that referenced this pull request Jan 20, 2025
…SExt

PR llvm#117350 made changes to the SLP vectorizer which introduced
a regression on ARM vectorization benchmarks. This was due
to the changes assuming that SExt/ZExt vector instructions have
constant cost. This behaviour is expected for RISCV but not on ARM
where we take into account source and destination type of SExt/ZExt
instructions when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
nasherm added a commit to nasherm/llvm-project that referenced this pull request Jan 20, 2025
…SExt

PR llvm#117350 made changes to the SLP vectorizer which introduced
a regression on ARM vectorization benchmarks. This was due
to the changes assuming that SExt/ZExt vector instructions have
constant cost. This behaviour is expected for RISCV but not on ARM
where we take into account source and destination type of SExt/ZExt
instructions when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
nasherm added a commit to nasherm/llvm-project that referenced this pull request Jan 21, 2025
…SExt

PR llvm#117350 made changes to the SLP vectorizer which introduced
a regression on ARM vectorization benchmarks. This was due
to the changes assuming that SExt/ZExt vector instructions have
constant cost. This behaviour is expected for RISCV but not on ARM
where we take into account source and destination type of SExt/ZExt
instructions when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
nasherm added a commit to nasherm/llvm-project that referenced this pull request Jan 30, 2025
…SExt

PR llvm#117350 made changes to the SLP vectorizer which introduced
a regression on ARM vectorization benchmarks. This was due
to the changes assuming that SExt/ZExt vector instructions have
constant cost. This behaviour is expected for RISCV but not on ARM
where we take into account source and destination type of SExt/ZExt
instructions when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
nasherm added a commit that referenced this pull request Mar 19, 2025
PR #117350 made changes to the SLP vectorizer which introduced a
regression on some ARM benchmarks. Investigation narrowed it down to
suboptimal codegen for benchmarks that previously only used scalar (U/S)MLAL
instructions. The linked change meant the SLPVectorizer thought that
these could be vectorized. This change makes the cost of muls in
(U/S)MLAL patterns slightly cheaper to make sure scalar instructions are
preferred in these cases over SLP vectorization on targets supporting DSP
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants