Skip to content

Commit 6b109a3

Browse files
[SLP]Initial support for non-power-of-2 (but still whole register) number of elements in operands.
Patch adds basic support for non-power-of-2 number of elements in operands. The patch still requires that this number addresses whole registers. Reviewers: RKSimon, preames Reviewed By: preames Pull Request: #107273
1 parent a514457 commit 6b109a3

File tree

3 files changed

+98
-34
lines changed

3 files changed

+98
-34
lines changed

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2538,7 +2538,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
25382538

25392539
unsigned getNumberOfParts(Type *Tp) {
25402540
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
2541-
return LT.first.isValid() ? *LT.first.getValue() : 0;
2541+
if (!LT.first.isValid())
2542+
return 0;
2543+
// Try to find actual number of parts for non-power-of-2 elements as
2544+
// ceil(num-of-elements/num-of-subtype-elements).
2545+
if (auto *FTp = dyn_cast<FixedVectorType>(Tp);
2546+
Tp && LT.second.isFixedLengthVector() &&
2547+
!has_single_bit(FTp->getNumElements())) {
2548+
if (auto *SubTp = dyn_cast_if_present<FixedVectorType>(
2549+
EVT(LT.second).getTypeForEVT(Tp->getContext()));
2550+
SubTp && SubTp->getElementType() == FTp->getElementType())
2551+
return divideCeil(FTp->getNumElements(), SubTp->getNumElements());
2552+
}
2553+
return *LT.first.getValue();
25422554
}
25432555

25442556
InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *,

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,20 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
260260
VF * getNumElements(ScalarTy));
261261
}
262262

263+
/// Returns the number of elements of the given type \p Ty, not less than \p Sz,
264+
/// which forms a type that is split by \p TTI into whole vector types during
265+
/// legalization.
266+
static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
267+
Type *Ty, unsigned Sz) {
268+
if (!isValidElementType(Ty))
269+
return bit_ceil(Sz);
270+
// Find the number of elements, which forms full vectors.
271+
const unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
272+
if (NumParts == 0 || NumParts >= Sz)
273+
return bit_ceil(Sz);
274+
return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
275+
}
276+
263277
static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
264278
SmallVectorImpl<int> &Mask) {
265279
// The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -394,7 +408,7 @@ static bool isVectorLikeInstWithConstOps(Value *V) {
394408
/// total number of elements \p Size and number of registers (parts) \p
395409
/// NumParts.
396410
static unsigned getPartNumElems(unsigned Size, unsigned NumParts) {
397-
return PowerOf2Ceil(divideCeil(Size, NumParts));
411+
return std::min<unsigned>(Size, bit_ceil(divideCeil(Size, NumParts)));
398412
}
399413

400414
/// Returns correct remaining number of elements, considering total amount \p
@@ -1222,6 +1236,22 @@ static bool doesNotNeedToSchedule(ArrayRef<Value *> VL) {
12221236
(all_of(VL, isUsedOutsideBlock) || all_of(VL, areAllOperandsNonInsts));
12231237
}
12241238

1239+
/// Returns true if widened type of \p Ty elements with size \p Sz represents
1240+
/// a full vector type, i.e. adding an extra element results in extra parts upon type
1241+
/// legalization.
1242+
static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
1243+
unsigned Sz) {
1244+
if (Sz <= 1)
1245+
return false;
1246+
if (!isValidElementType(Ty) && !isa<FixedVectorType>(Ty))
1247+
return false;
1248+
if (has_single_bit(Sz))
1249+
return true;
1250+
const unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
1251+
return NumParts > 0 && NumParts < Sz && has_single_bit(Sz / NumParts) &&
1252+
Sz % NumParts == 0;
1253+
}
1254+
12251255
namespace slpvectorizer {
12261256

12271257
/// Bottom Up SLP Vectorizer.
@@ -3311,6 +3341,15 @@ class BoUpSLP {
33113341
/// Return true if this is a non-power-of-2 node.
33123342
bool isNonPowOf2Vec() const {
33133343
bool IsNonPowerOf2 = !has_single_bit(Scalars.size());
3344+
return IsNonPowerOf2;
3345+
}
3346+
3347+
/// Return true if this is a node, which tries to vectorize number of
3348+
/// elements, forming whole vectors.
3349+
bool
3350+
hasNonWholeRegisterOrNonPowerOf2Vec(const TargetTransformInfo &TTI) const {
3351+
bool IsNonPowerOf2 = !hasFullVectorsOrPowerOf2(
3352+
TTI, getValueType(Scalars.front()), Scalars.size());
33143353
assert((!IsNonPowerOf2 || ReuseShuffleIndices.empty()) &&
33153354
"Reshuffling not supported with non-power-of-2 vectors yet.");
33163355
return IsNonPowerOf2;
@@ -3430,8 +3469,10 @@ class BoUpSLP {
34303469
Last->State = EntryState;
34313470
// FIXME: Remove once support for ReuseShuffleIndices has been implemented
34323471
// for non-power-of-two vectors.
3433-
assert((has_single_bit(VL.size()) || ReuseShuffleIndices.empty()) &&
3434-
"Reshuffling scalars not yet supported for nodes with padding");
3472+
assert(
3473+
(hasFullVectorsOrPowerOf2(*TTI, getValueType(VL.front()), VL.size()) ||
3474+
ReuseShuffleIndices.empty()) &&
3475+
"Reshuffling scalars not yet supported for nodes with padding");
34353476
Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
34363477
ReuseShuffleIndices.end());
34373478
if (ReorderIndices.empty()) {
@@ -5269,7 +5310,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
52695310
// node.
52705311
if (!TE.ReuseShuffleIndices.empty()) {
52715312
// FIXME: Support ReuseShuffleIndices for non-power-of-two vectors.
5272-
assert(!TE.isNonPowOf2Vec() &&
5313+
assert(!TE.hasNonWholeRegisterOrNonPowerOf2Vec(*TTI) &&
52735314
"Reshuffling scalars not yet supported for nodes with padding");
52745315

52755316
if (isSplat(TE.Scalars))
@@ -5509,7 +5550,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
55095550
}
55105551
// FIXME: Remove the non-power-of-two check once findReusedOrderedScalars
55115552
// has been audited for correctness with non-power-of-two vectors.
5512-
if (!TE.isNonPowOf2Vec())
5553+
if (!TE.hasNonWholeRegisterOrNonPowerOf2Vec(*TTI))
55135554
if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE))
55145555
return CurrentOrder;
55155556
}
@@ -5662,15 +5703,18 @@ void BoUpSLP::reorderTopToBottom() {
56625703
});
56635704

56645705
// Reorder the graph nodes according to their vectorization factor.
5665-
for (unsigned VF = VectorizableTree.front()->getVectorFactor(); VF > 1;
5666-
VF = bit_ceil(VF) / 2) {
5706+
for (unsigned VF = VectorizableTree.front()->getVectorFactor();
5707+
!VFToOrderedEntries.empty() && VF > 1; VF -= 2 - (VF & 1U)) {
56675708
auto It = VFToOrderedEntries.find(VF);
56685709
if (It == VFToOrderedEntries.end())
56695710
continue;
56705711
// Try to find the most profitable order. We just are looking for the most
56715712
// used order and reorder scalar elements in the nodes according to this
56725713
// mostly used order.
56735714
ArrayRef<TreeEntry *> OrderedEntries = It->second.getArrayRef();
5715+
// Delete VF entry upon exit.
5716+
auto Cleanup = make_scope_exit([&]() { VFToOrderedEntries.erase(It); });
5717+
56745718
// All operands are reordered and used only in this node - propagate the
56755719
// most used order to the user node.
56765720
MapVector<OrdersType, unsigned,
@@ -7529,33 +7573,36 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
75297573
UniqueValues.emplace_back(V);
75307574
}
75317575
size_t NumUniqueScalarValues = UniqueValues.size();
7532-
if (NumUniqueScalarValues == VL.size()) {
7576+
bool IsFullVectors = hasFullVectorsOrPowerOf2(
7577+
*TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
7578+
if (NumUniqueScalarValues == VL.size() &&
7579+
(VectorizeNonPowerOf2 || IsFullVectors)) {
75337580
ReuseShuffleIndices.clear();
75347581
} else {
75357582
// FIXME: Reshuffing scalars is not supported yet for non-power-of-2 ops.
7536-
if ((UserTreeIdx.UserTE && UserTreeIdx.UserTE->isNonPowOf2Vec()) ||
7537-
!llvm::has_single_bit(VL.size())) {
7583+
if ((UserTreeIdx.UserTE &&
7584+
UserTreeIdx.UserTE->hasNonWholeRegisterOrNonPowerOf2Vec(*TTI)) ||
7585+
!has_single_bit(VL.size())) {
75387586
LLVM_DEBUG(dbgs() << "SLP: Reshuffling scalars not yet supported "
75397587
"for nodes with padding.\n");
75407588
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
75417589
return false;
75427590
}
75437591
LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n");
7544-
if (NumUniqueScalarValues <= 1 ||
7545-
(UniquePositions.size() == 1 && all_of(UniqueValues,
7546-
[](Value *V) {
7547-
return isa<UndefValue>(V) ||
7548-
!isConstant(V);
7549-
})) ||
7550-
!llvm::has_single_bit<uint32_t>(NumUniqueScalarValues)) {
7592+
if (NumUniqueScalarValues <= 1 || !IsFullVectors ||
7593+
(UniquePositions.size() == 1 && all_of(UniqueValues, [](Value *V) {
7594+
return isa<UndefValue>(V) || !isConstant(V);
7595+
}))) {
75517596
if (DoNotFail && UniquePositions.size() > 1 &&
75527597
NumUniqueScalarValues > 1 && S.MainOp->isSafeToRemove() &&
75537598
all_of(UniqueValues, [=](Value *V) {
75547599
return isa<ExtractElementInst>(V) ||
75557600
areAllUsersVectorized(cast<Instruction>(V),
75567601
UserIgnoreList);
75577602
})) {
7558-
unsigned PWSz = PowerOf2Ceil(UniqueValues.size());
7603+
// Find the number of elements, which forms full vectors.
7604+
unsigned PWSz = getFullVectorNumberOfElements(
7605+
*TTI, UniqueValues.front()->getType(), UniqueValues.size());
75597606
if (PWSz == VL.size()) {
75607607
ReuseShuffleIndices.clear();
75617608
} else {
@@ -9793,9 +9840,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
97939840
return nullptr;
97949841
Value *VecBase = nullptr;
97959842
ArrayRef<Value *> VL = E->Scalars;
9796-
// If the resulting type is scalarized, do not adjust the cost.
9797-
if (NumParts == VL.size())
9798-
return nullptr;
97999843
// Check if it can be considered reused if same extractelements were
98009844
// vectorized already.
98019845
bool PrevNodeFound = any_of(
@@ -10450,7 +10494,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1045010494
InsertMask[Idx] = I + 1;
1045110495
}
1045210496
unsigned VecScalarsSz = PowerOf2Ceil(NumElts);
10453-
if (NumOfParts > 0)
10497+
if (NumOfParts > 0 && NumOfParts < NumElts)
1045410498
VecScalarsSz = PowerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts);
1045510499
unsigned VecSz = (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) *
1045610500
VecScalarsSz;
@@ -17785,7 +17829,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
1778517829
for (unsigned I = NextInst; I < MaxInst; ++I) {
1778617830
unsigned ActualVF = std::min(MaxInst - I, VF);
1778717831

17788-
if (!has_single_bit(ActualVF))
17832+
if (!hasFullVectorsOrPowerOf2(*TTI, ScalarTy, ActualVF))
1778917833
continue;
1779017834

1779117835
if (MaxVFOnly && ActualVF < MaxVF)

llvm/test/Transforms/SLPVectorizer/reduction-whole-regs-loads.ll

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2-
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=riscv64-unknown-linux -mattr=+v -slp-threshold=-100 | FileCheck %s
2+
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=riscv64-unknown-linux -mattr=+v -slp-threshold=-100 | FileCheck %s --check-prefix=RISCV
33
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=x86_64-unknown-linux -slp-threshold=-100 | FileCheck %s
44
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=aarch64-unknown-linux -slp-threshold=-100 | FileCheck %s
55
; REQUIRES: aarch64-registered-target, x86-registered-target, riscv-registered-target
66

77
define i64 @test(ptr %p) {
8+
; RISCV-LABEL: @test(
9+
; RISCV-NEXT: entry:
10+
; RISCV-NEXT: [[ARRAYIDX_4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 4
11+
; RISCV-NEXT: [[TMP0:%.*]] = load <4 x i64>, ptr [[P]], align 4
12+
; RISCV-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr [[ARRAYIDX_4]], align 4
13+
; RISCV-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[TMP0]], <4 x i64> poison, <8 x i32> <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0, i32 0>
14+
; RISCV-NEXT: [[TMP3:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v4i64(<8 x i64> [[TMP2]], <4 x i64> [[TMP0]], i64 0)
15+
; RISCV-NEXT: [[TMP4:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v2i64(<8 x i64> [[TMP3]], <2 x i64> [[TMP1]], i64 4)
16+
; RISCV-NEXT: [[TMP5:%.*]] = mul <8 x i64> [[TMP4]], <i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42>
17+
; RISCV-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
18+
; RISCV-NEXT: ret i64 [[TMP6]]
19+
;
820
; CHECK-LABEL: @test(
921
; CHECK-NEXT: entry:
10-
; CHECK-NEXT: [[ARRAYIDX_4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 4
11-
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i64>, ptr [[P]], align 4
12-
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr [[ARRAYIDX_4]], align 4
13-
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[TMP0]], <4 x i64> poison, <8 x i32> <i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 0, i32 0>
14-
; CHECK-NEXT: [[TMP3:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v4i64(<8 x i64> [[TMP2]], <4 x i64> [[TMP0]], i64 0)
15-
; CHECK-NEXT: [[TMP4:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v2i64(<8 x i64> [[TMP3]], <2 x i64> [[TMP1]], i64 4)
16-
; CHECK-NEXT: [[TMP5:%.*]] = mul <8 x i64> [[TMP4]], <i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42>
17-
; CHECK-NEXT: [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
18-
; CHECK-NEXT: ret i64 [[TMP6]]
22+
; CHECK-NEXT: [[TMP0:%.*]] = load <6 x i64>, ptr [[P:%.*]], align 4
23+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <6 x i64> [[TMP0]], <6 x i64> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 0, i32 0>
24+
; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i64> [[TMP1]], <i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42, i64 42>
25+
; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP2]])
26+
; CHECK-NEXT: ret i64 [[TMP3]]
1927
;
2028
entry:
2129
%arrayidx.1 = getelementptr inbounds i64, ptr %p, i64 1

0 commit comments

Comments
 (0)