[SLP]Model reduction_add(ext(<n x i1>)) as ext(ctpop(bitcast <n x i1> to int n)) #116875


Conversation

alexey-bataev (Member)

Currently, sequences of the form reduction_add(ext(<n x i1>)) are modeled as vector
extensions plus a reduction add, but the instcombiner later transforms them into
ext(ctpop(bitcast <n x i1> to int n)). The patch adds direct support for this
pattern in the SLP vectorizer, which enables better cost estimation.

Benchmark results (AVX512, -O3+LTO):

CINT2006/445.gobmk - extra vector code
Prolangs-C/bison - extra vector code
Benchmarks/NPB-serial/is - 16 x and 8 x reductions vectorized as a single 24 x reduction
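
For illustration, a minimal IR sketch of the pattern in question; the function names and the 8-element width are only an example, and the updated tests in the diff below show the same rewrite.

; Before: the i1 reduction is modeled as a zero-extension plus a vector reduce.add.
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)

define i16 @count_bits_ext(<8 x i1> %mask) {
  %ext = zext <8 x i1> %mask to <8 x i16>
  %sum = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %ext)
  ret i16 %sum
}

; After: the same reduction expressed as ext(ctpop(bitcast <8 x i1> to i8)),
; which is what instcombine eventually produces and what the cost model now
; accounts for directly.
declare i8 @llvm.ctpop.i8(i8)

define i16 @count_bits_ctpop(<8 x i1> %mask) {
  %bits = bitcast <8 x i1> %mask to i8
  %pop = call i8 @llvm.ctpop.i8(i8 %bits)
  %sum = zext i8 %pop to i16
  ret i16 %sum
}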

Created using spr 1.3.5
llvmbot (Member) commented Nov 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Currently, sequences of the form reduction_add(ext(<n x i1>)) are modeled as vector
extensions plus a reduction add, but the instcombiner later transforms them into
ext(ctpop(bitcast <n x i1> to int n)). The patch adds direct support for this
pattern in the SLP vectorizer, which enables better cost estimation.

Benchmark results (AVX512, -O3+LTO):

CINT2006/445.gobmk - extra vector code
Prolangs-C/bison - extra vector code
Benchmarks/NPB-serial/is - 16 x and 8 x reductions vectorized as a single 24 x reduction


Full diff: https://github.com/llvm/llvm-project/pull/116875.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+80-27)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll (+3-2)
  • (modified) llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll (+3-2)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e70627b6afc10d..fe5099d68024c3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns reduction bitwidth and signedness, if it does not match the
+  /// original requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
@@ -17885,24 +17897,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
@@ -19758,8 +19783,8 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost TreeCost = V.getTreeCost(VL);
-        InstructionCost ReductionCost =
-            getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+        InstructionCost ReductionCost = getReductionCost(
+            TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -19864,10 +19889,12 @@ class HorizontalReduction {
                 createStrideMask(I, ScalarTyNumElements, VL.size());
             Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
             ReducedSubTree = Builder.CreateInsertElement(
-                ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+                ReducedSubTree,
+                emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
           }
         } else {
-          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+          ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                         RdxRootInst->getType());
         }
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20048,12 +20075,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20112,8 +20140,22 @@ class HorizontalReduction {
               VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
               /*Extract*/ false, TTI::TCK_RecipThroughput);
         } else {
-          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                       CostKind);
+          auto [Bitwidth, IsSigned] =
+              BitwidthAndSign.value_or(std::make_pair(0u, false));
+          if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+            // Represent vector_reduce_add(ZExt(<n x i1>)) to
+            // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+            auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+            IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+            VectorCost =
+                TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                      getWidenedType(ScalarTy, ReduxWidth),
+                                      TTI::CastContextHint::None, CostKind) +
+                TTI->getIntrinsicInstrCost(ICA, CostKind);
+          } else {
+            VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                         FMF, CostKind);
+          }
         }
       }
       ScalarCost = EvaluateScalarCost([&]() {
@@ -20150,11 +20192,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
index ecf85159efdfbd..f00b846bf4f5bd 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll
@@ -11,8 +11,9 @@ define i16 @test(i16 %call37) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]])
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i1> [[TMP4]] to i8
+; CHECK-NEXT:    [[TMP7:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP8]])
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP7]] to i16
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i16 [[TMP6]], 0
 ; CHECK-NEXT:    ret i16 [[OP_RDX]]
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
index 89fcc7e983749b..303e31dfa5e64a 100644
--- a/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/zext-incoming-for-neg-icmp.ll
@@ -14,8 +14,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <4 x i1> [[TMP5]] to i4
+; CHECK-NEXT:    [[TMP11:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP10]])
+; CHECK-NEXT:    [[TMP7:%.*]] = zext i4 [[TMP11]] to i32
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]
 ; CHECK-NEXT:    ret i32 [[OP_RDX]]
 ;

llvmbot (Member) commented Nov 19, 2024

@llvm/pr-subscribers-vectorizers



github-actions bot commented Nov 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Created using spr 1.3.5
RKSimon (Collaborator) left a comment

LGTM

Created using spr 1.3.5
alexey-bataev merged commit 0298c59 into main Nov 21, 2024 (5 of 6 checks passed)
alexey-bataev deleted the users/alexey-bataev/spr/slpmodel-reduction_addextn-x-i1-as-extctpopbitcast-n-x-i1-to-int-n branch November 21, 2024 18:21
preames (Collaborator) commented Nov 21, 2024

Why not use the existing getExtendedReductionCost and push the logic about the ctpop inside that?

alexey-bataev (Member, Author) commented Nov 21, 2024

> Why not use the existing getExtendedReductionCost and push the logic about the ctpop inside that?

Good point, missed that we have this entry point, will fix it in a followup patch
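
To make the suggestion concrete, a minimal sketch of how this case might be costed through TTI's getExtendedReductionCost hook instead of hand-composing the bitcast and ctpop costs; the variable names (RdxOpcode, ScalarTy, ReduxWidth, BitwidthAndSign, etc.) refer to the getReductionCost snippet in the diff above, and both the hook's exact signature and this wiring are assumptions about a possible follow-up, not the actual follow-up patch.

// Hypothetical follow-up sketch: delegate the whole ext+reduce costing to TTI.
auto [Bitwidth, IsSigned] =
    BitwidthAndSign.value_or(std::make_pair(0u, false));
if (RdxKind == RecurKind::Add && Bitwidth == 1) {
  // Cost reduce.add(zext(<ReduxWidth x i1> to ScalarTy)) as one extended
  // reduction; targets that lower this to ctpop(bitcast) can reflect that in
  // their getExtendedReductionCost implementation.
  VectorCost = TTI->getExtendedReductionCost(
      RdxOpcode, /*IsUnsigned=*/!IsSigned, ScalarTy,
      getWidenedType(Type::getInt1Ty(ScalarTy->getContext()), ReduxWidth), FMF,
      CostKind);
} else {
  VectorCost =
      TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
}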

alexey-bataev added a commit that referenced this pull request Nov 22, 2024
… to int n))

Currently, sequences of the form reduction_add(ext(<n x i1>)) are modeled as vector
extensions plus a reduction add, but the instcombiner later transforms them into
ext(ctpop(bitcast <n x i1> to int n)). The patch adds direct support for this
pattern in the SLP vectorizer, which enables better cost estimation.

Benchmark results (AVX512, -O3+LTO):

CINT2006/445.gobmk - extra vector code
Prolangs-C/bison - extra vector code
Benchmarks/NPB-serial/is - 16 x and 8 x reductions vectorized as a single 24 x reduction

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: #116875