Skip to content

Commit 1813ffd

Browse files
authored
[SLP][REVEC] Make SLP support revectorization (-slp-revec) and add simple test. (#98269)
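This PR makes the SLP vectorizer support revectorization, i.e. combining instructions that already operate on fixed-width vector types (for example four <4 x i32> operations) into wider vector instructions (a single <16 x i32> operation). The functionality is controlled by a new option, -slp-revec, which is disabled by default. Reference: https://discourse.llvm.org/t/rfc-make-slp-vectorizer-revectorize-vector-instructions/79436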
1 parent fa0e529 commit 1813ffd

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ static cl::opt<bool>
113113
RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden,
114114
cl::desc("Run the SLP vectorization passes"));
115115

116+
static cl::opt<bool>
117+
SLPReVec("slp-revec", cl::init(false), cl::Hidden,
118+
cl::desc("Enable vectorization for wider vector utilization"));
119+
116120
static cl::opt<int>
117121
SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
118122
cl::desc("Only vectorize if you gain more than this "
@@ -227,13 +231,26 @@ static const unsigned MaxPHINumOperands = 128;
227231
/// avoids spending time checking the cost model and realizing that they will
228232
/// be inevitably scalarized.
229233
static bool isValidElementType(Type *Ty) {
234+
// TODO: Support ScalableVectorType.
235+
if (SLPReVec && isa<FixedVectorType>(Ty))
236+
Ty = Ty->getScalarType();
230237
return VectorType::isValidElementType(Ty) && !Ty->isX86_FP80Ty() &&
231238
!Ty->isPPC_FP128Ty();
232239
}
233240

241+
/// \returns the number of elements for Ty.
242+
static unsigned getNumElements(Type *Ty) {
243+
assert(!isa<ScalableVectorType>(Ty) &&
244+
"ScalableVectorType is not supported.");
245+
if (auto *VecTy = dyn_cast<FixedVectorType>(Ty))
246+
return VecTy->getNumElements();
247+
return 1;
248+
}
249+
234250
/// \returns the vector type of ScalarTy based on vectorization factor.
235251
static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
236-
return FixedVectorType::get(ScalarTy, VF);
252+
return FixedVectorType::get(ScalarTy->getScalarType(),
253+
VF * getNumElements(ScalarTy));
237254
}
238255

239256
/// \returns True if the value is a constant (but not globals/constant
@@ -6779,15 +6796,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
67796796
}
67806797

67816798
// Don't handle vectors.
6782-
if (S.OpValue->getType()->isVectorTy() &&
6799+
if (!SLPReVec && S.OpValue->getType()->isVectorTy() &&
67836800
!isa<InsertElementInst>(S.OpValue)) {
67846801
LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
67856802
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
67866803
return;
67876804
}
67886805

67896806
if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue))
6790-
if (SI->getValueOperand()->getType()->isVectorTy()) {
6807+
if (!SLPReVec && SI->getValueOperand()->getType()->isVectorTy()) {
67916808
LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n");
67926809
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
67936810
return;
@@ -11833,10 +11850,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1183311850
Value *castToScalarTyElem(Value *V,
1183411851
std::optional<bool> IsSigned = std::nullopt) {
1183511852
auto *VecTy = cast<VectorType>(V->getType());
11836-
if (VecTy->getElementType() == ScalarTy)
11853+
assert(getNumElements(ScalarTy) < getNumElements(VecTy) &&
11854+
(getNumElements(VecTy) % getNumElements(ScalarTy) == 0));
11855+
if (VecTy->getElementType() == ScalarTy->getScalarType())
1183711856
return V;
1183811857
return Builder.CreateIntCast(
11839-
V, VectorType::get(ScalarTy, VecTy->getElementCount()),
11858+
V, VectorType::get(ScalarTy->getScalarType(), VecTy->getElementCount()),
1184011859
IsSigned.value_or(!isKnownNonNegative(V, SimplifyQuery(*R.DL))));
1184111860
}
1184211861

@@ -12221,7 +12240,8 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx,
1222112240
return ShuffleBuilder.finalize(std::nullopt);
1222212241
};
1222312242
Value *V = vectorizeTree(VE, PostponedPHIs);
12224-
if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) {
12243+
if (VF * getNumElements(VL[0]->getType()) !=
12244+
cast<FixedVectorType>(V->getType())->getNumElements()) {
1222512245
if (!VE->ReuseShuffleIndices.empty()) {
1222612246
// Reshuffle to get only unique values.
1222712247
// If some of the scalars are duplicated in the vectorization
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes=slp-vectorizer -S -slp-revec -slp-max-reg-size=1024 -slp-threshold=-100 %s | FileCheck %s
3+
4+
define void @test1(ptr %a, ptr %b, ptr %c) {
5+
; CHECK-LABEL: @test1(
6+
; CHECK-NEXT: entry:
7+
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr [[A:%.*]], align 4
8+
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr [[B:%.*]], align 4
9+
; CHECK-NEXT: [[TMP2:%.*]] = add <16 x i32> [[TMP1]], [[TMP0]]
10+
; CHECK-NEXT: store <16 x i32> [[TMP2]], ptr [[C:%.*]], align 4
11+
; CHECK-NEXT: ret void
12+
;
13+
entry:
14+
%arrayidx3 = getelementptr inbounds i32, ptr %a, i64 4
15+
%arrayidx7 = getelementptr inbounds i32, ptr %a, i64 8
16+
%arrayidx11 = getelementptr inbounds i32, ptr %a, i64 12
17+
%0 = load <4 x i32>, ptr %a, align 4
18+
%1 = load <4 x i32>, ptr %arrayidx3, align 4
19+
%2 = load <4 x i32>, ptr %arrayidx7, align 4
20+
%3 = load <4 x i32>, ptr %arrayidx11, align 4
21+
%arrayidx19 = getelementptr inbounds i32, ptr %b, i64 4
22+
%arrayidx23 = getelementptr inbounds i32, ptr %b, i64 8
23+
%arrayidx27 = getelementptr inbounds i32, ptr %b, i64 12
24+
%4 = load <4 x i32>, ptr %b, align 4
25+
%5 = load <4 x i32>, ptr %arrayidx19, align 4
26+
%6 = load <4 x i32>, ptr %arrayidx23, align 4
27+
%7 = load <4 x i32>, ptr %arrayidx27, align 4
28+
%add.i = add <4 x i32> %4, %0
29+
%add.i63 = add <4 x i32> %5, %1
30+
%add.i64 = add <4 x i32> %6, %2
31+
%add.i65 = add <4 x i32> %7, %3
32+
%arrayidx36 = getelementptr inbounds i32, ptr %c, i64 4
33+
%arrayidx39 = getelementptr inbounds i32, ptr %c, i64 8
34+
%arrayidx42 = getelementptr inbounds i32, ptr %c, i64 12
35+
store <4 x i32> %add.i, ptr %c, align 4
36+
store <4 x i32> %add.i63, ptr %arrayidx36, align 4
37+
store <4 x i32> %add.i64, ptr %arrayidx39, align 4
38+
store <4 x i32> %add.i65, ptr %arrayidx42, align 4
39+
ret void
40+
}

0 commit comments

Comments
 (0)