Skip to content

[SLP]Unify getNumberOfParts use #124774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

alexey-bataev
Copy link
Member

Adds getNumberOfParts and uses it instead of similar code across code
base, fixes analysis of non-vectorizable types in
computeMinimumValueSizes.

Created using spr 1.3.5
@llvmbot
Copy link
Member

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Alexey Bataev (alexey-bataev)

Changes

Adds getNumberOfParts and uses it instead of similar code across code
base, fixes analysis of non-vectorizable types in
computeMinimumValueSizes.


Full diff: https://github.com/llvm/llvm-project/pull/124774.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+52-52)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll (+10-2)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f73ad1b15891a3..7f86651256e577 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1314,6 +1314,22 @@ static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
          Sz % NumParts == 0;
 }
 
+/// Returns number of parts, the type \p VecTy will be split at the codegen
+/// phase. If the type is going to be scalarized or does not uses whole
+/// registers, returns 1.
+static unsigned
+getNumberOfParts(const TargetTransformInfo &TTI, VectorType *VecTy,
+                 const unsigned Limit = std::numeric_limits<unsigned>::max()) {
+  unsigned NumParts = TTI.getNumberOfParts(VecTy);
+  if (NumParts == 0 || NumParts >= Limit)
+    return 1;
+  unsigned Sz = getNumElements(VecTy);
+  if (NumParts >= Sz || Sz % NumParts != 0 ||
+      !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(), Sz / NumParts))
+    return 1;
+  return NumParts;
+}
+
 namespace slpvectorizer {
 
 /// Bottom Up SLP Vectorizer.
@@ -4618,12 +4634,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
   if (!isValidElementType(ScalarTy))
     return std::nullopt;
   auto *VecTy = getWidenedType(ScalarTy, NumScalars);
-  int NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= NumScalars ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, NumScalars);
   SmallVector<int> ExtractMask;
   SmallVector<int> Mask;
   SmallVector<SmallVector<const TreeEntry *>> Entries;
@@ -5574,8 +5585,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
       }
     }
     if (Sz == 2 && TE.getVectorFactor() == 4 &&
-        TTI->getNumberOfParts(getWidenedType(TE.Scalars.front()->getType(),
-                                             2 * TE.getVectorFactor())) == 1)
+        ::getNumberOfParts(*TTI, getWidenedType(TE.Scalars.front()->getType(),
+                                                2 * TE.getVectorFactor())) == 1)
       return std::nullopt;
     if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices,
                                                      Sz)) {
@@ -9847,12 +9858,15 @@ void BoUpSLP::transformNodes() {
           // only with the single non-undef element).
           bool IsSplat = isSplat(Slice);
           if (Slices.empty() || !IsSplat ||
-              (VF <= 2 && 2 * std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), VF)),
-                                         1U, VF - 1) !=
-                              std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), 2 * VF)),
-                                         1U, 2 * VF)) ||
+              (VF <= 2 &&
+               2 * std::clamp(
+                       ::getNumberOfParts(
+                           *TTI, getWidenedType(Slice.front()->getType(), VF)),
+                       1U, VF - 1) !=
+                   std::clamp(::getNumberOfParts(
+                                  *TTI, getWidenedType(Slice.front()->getType(),
+                                                       2 * VF)),
+                              1U, 2 * VF)) ||
               count(Slice, Slice.front()) ==
                   static_cast<long>(isa<UndefValue>(Slice.front()) ? VF - 1
                                                                    : 1)) {
@@ -10793,12 +10807,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -10813,12 +10822,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -11351,7 +11355,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     unsigned const NumElts = SrcVecTy->getNumElements();
     unsigned const NumScalars = VL.size();
 
-    unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);
+    unsigned NumOfParts = ::getNumberOfParts(*TTI, SrcVecTy);
 
     SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
     unsigned OffsetBeg = *getElementIndex(VL.front());
@@ -14862,12 +14866,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
   SmallVector<SmallVector<const TreeEntry *>> Entries;
   Type *OrigScalarTy = GatheredScalars.front()->getType();
   auto *VecTy = getWidenedType(ScalarTy, GatheredScalars.size());
-  unsigned NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, GatheredScalars.size());
   if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
     // Check for gathered extracts.
     bool Resized = false;
@@ -14899,12 +14898,8 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             Resized = true;
             GatheredScalars.append(VF - GatheredScalars.size(),
                                    PoisonValue::get(OrigScalarTy));
-            NumParts = TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF));
-            if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-                VecTy->getNumElements() % NumParts != 0 ||
-                !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                          VecTy->getNumElements() / NumParts))
-              NumParts = 1;
+            NumParts =
+                ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF), VF);
           }
       }
     }
@@ -17049,10 +17044,10 @@ void BoUpSLP::optimizeGatherSequence() {
     // Check if the last undefs actually change the final number of used vector
     // registers.
     return SM1.size() - LastUndefsCnt > 1 &&
-           TTI->getNumberOfParts(SI1->getType()) ==
-               TTI->getNumberOfParts(
-                   getWidenedType(SI1->getType()->getElementType(),
-                                  SM1.size() - LastUndefsCnt));
+           ::getNumberOfParts(*TTI, SI1->getType()) ==
+               ::getNumberOfParts(
+                   *TTI, getWidenedType(SI1->getType()->getElementType(),
+                                        SM1.size() - LastUndefsCnt));
   };
   // Perform O(N^2) search over the gather/shuffle sequences and merge identical
   // instructions. TODO: We can further optimize this scan if we split the
@@ -17829,9 +17824,12 @@ bool BoUpSLP::collectValuesToDemote(
       const unsigned VF = E.Scalars.size();
       Type *OrigScalarTy = E.Scalars.front()->getType();
       if (UniqueBases.size() <= 2 ||
-          TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF)) ==
-              TTI->getNumberOfParts(getWidenedType(
-                  IntegerType::get(OrigScalarTy->getContext(), BitWidth), VF)))
+          ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF)) ==
+              ::getNumberOfParts(
+                  *TTI,
+                  getWidenedType(
+                      IntegerType::get(OrigScalarTy->getContext(), BitWidth),
+                      VF)))
         ToDemote.push_back(E.Idx);
     }
     return Res;
@@ -18241,8 +18239,8 @@ void BoUpSLP::computeMinimumValueSizes() {
                [&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
       return 0u;
 
-    unsigned NumParts = TTI->getNumberOfParts(
-        getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
+    unsigned NumParts = ::getNumberOfParts(
+        *TTI, getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
 
     // The maximum bit width required to represent all the values that can be
     // demoted without loss of precision. It would be safe to truncate the roots
@@ -18302,8 +18300,10 @@ void BoUpSLP::computeMinimumValueSizes() {
     // use - ignore it.
     if (NumParts > 1 &&
         NumParts ==
-            TTI->getNumberOfParts(getWidenedType(
-                IntegerType::get(F->getContext(), bit_ceil(MaxBitWidth)), VF)))
+            ::getNumberOfParts(
+                *TTI, getWidenedType(IntegerType::get(F->getContext(),
+                                                      bit_ceil(MaxBitWidth)),
+                                     VF)))
       return 0u;
 
     unsigned Opcode = E.getOpcode();
@@ -20086,14 +20086,14 @@ class HorizontalReduction {
         ReduxWidth =
             getFloorFullVectorNumberOfElements(TTI, ScalarTy, ReduxWidth);
         VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-        NumParts = TTI.getNumberOfParts(Tp);
+        NumParts = ::getNumberOfParts(TTI, Tp);
         NumRegs =
             TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
         while (NumParts > NumRegs) {
           assert(ReduxWidth > 0 && "ReduxWidth is unexpectedly 0.");
           ReduxWidth = bit_floor(ReduxWidth - 1);
           VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-          NumParts = TTI.getNumberOfParts(Tp);
+          NumParts = ::getNumberOfParts(TTI, Tp);
           NumRegs =
               TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
         }
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
index 6388cc2dedc73a..085d7a64fc9ac9 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
@@ -7,9 +7,17 @@ define void @partial_vec_invalid_cost() #0 {
 ; CHECK-LABEL: define void @partial_vec_invalid_cost(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LSHR_1:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[LSHR_2:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[TRUNC_I96_1:%.*]] = trunc i96 [[LSHR_1]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_2:%.*]] = trunc i96 [[LSHR_2]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_3:%.*]] = trunc i96 0 to i32
+; CHECK-NEXT:    [[TRUNC_I96_4:%.*]] = trunc i96 0 to i32
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[OP_RDX:%.*]] = or i32 [[TMP1]], [[TRUNC_I96_1]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = or i32 [[TRUNC_I96_2]], [[TRUNC_I96_3]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = or i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[OP_RDX2]], [[TRUNC_I96_4]]
 ; CHECK-NEXT:    [[STORE_THIS:%.*]] = zext i32 [[OP_RDX3]] to i96
 ; CHECK-NEXT:    store i96 [[STORE_THIS]], ptr null, align 16
 ; CHECK-NEXT:    ret void

@alexey-bataev alexey-bataev requested a review from RKSimon January 28, 2025 15:56
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one minor

std::clamp(::getNumberOfParts(
*TTI, getWidenedType(Slice.front()->getType(),
2 * VF)),
1U, 2 * VF)) ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty horrible to read - any chance it can be cleaned up?

Created using spr 1.3.5
@alexey-bataev alexey-bataev merged commit 947d8eb into main Jan 28, 2025
5 of 7 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpunify-getnumberofparts-use branch January 28, 2025 17:16
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 28, 2025
Adds getNumberOfParts and uses it instead of similar code across code
base, fixes analysis of non-vectorizable types in
computeMinimumValueSizes.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: llvm/llvm-project#124774
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants