Skip to content

Commit a693f23

Browse files
authored
[SLP][REVEC] Fix CompressVectorize not expanding the mask when REVEC is enabled. (#135174)
1 parent 85614e1 commit a693f23

File tree

2 files changed

+64
-18
lines changed

2 files changed

+64
-18
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,17 @@ static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
12781278
return OpcodeMask;
12791279
}
12801280

1281+
/// Replicates the given \p Val \p VF times.
1282+
static SmallVector<Constant *> replicateMask(ArrayRef<Constant *> Val,
1283+
unsigned VF) {
1284+
assert(none_of(Val, [](Constant *C) { return C->getType()->isVectorTy(); }) &&
1285+
"Expected scalar constants.");
1286+
SmallVector<Constant *> NewVal(Val.size() * VF);
1287+
for (auto [I, V] : enumerate(Val))
1288+
std::fill_n(NewVal.begin() + I * VF, VF, V);
1289+
return NewVal;
1290+
}
1291+
12811292
namespace llvm {
12821293

12831294
static void inversePermutation(ArrayRef<unsigned> Indices,
@@ -12202,32 +12213,24 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1220212213
unsigned VF = VL.size();
1220312214
if (MaskVF != 0)
1220412215
VF = std::min(VF, MaskVF);
12216+
Type *VLScalarTy = VL.front()->getType();
1220512217
for (Value *V : VL.take_front(VF)) {
12218+
Type *ScalarTy = VLScalarTy->getScalarType();
12219+
if (isa<PoisonValue>(V)) {
12220+
Vals.push_back(PoisonValue::get(ScalarTy));
12221+
continue;
12222+
}
1220612223
if (isa<UndefValue>(V)) {
12207-
Vals.push_back(cast<Constant>(V));
12224+
Vals.push_back(UndefValue::get(ScalarTy));
1220812225
continue;
1220912226
}
12210-
Vals.push_back(Constant::getNullValue(V->getType()));
12227+
Vals.push_back(Constant::getNullValue(ScalarTy));
1221112228
}
12212-
if (auto *VecTy = dyn_cast<FixedVectorType>(Vals.front()->getType())) {
12229+
if (auto *VecTy = dyn_cast<FixedVectorType>(VLScalarTy)) {
1221312230
assert(SLPReVec && "FixedVectorType is not expected.");
1221412231
// When REVEC is enabled, we need to expand vector types into scalar
1221512232
// types.
12216-
unsigned VecTyNumElements = VecTy->getNumElements();
12217-
SmallVector<Constant *> NewVals(VF * VecTyNumElements, nullptr);
12218-
for (auto [I, V] : enumerate(Vals)) {
12219-
Type *ScalarTy = V->getType()->getScalarType();
12220-
Constant *NewVal;
12221-
if (isa<PoisonValue>(V))
12222-
NewVal = PoisonValue::get(ScalarTy);
12223-
else if (isa<UndefValue>(V))
12224-
NewVal = UndefValue::get(ScalarTy);
12225-
else
12226-
NewVal = Constant::getNullValue(ScalarTy);
12227-
std::fill_n(NewVals.begin() + I * VecTyNumElements, VecTyNumElements,
12228-
NewVal);
12229-
}
12230-
Vals.swap(NewVals);
12233+
Vals = replicateMask(Vals, VecTy->getNumElements());
1223112234
}
1223212235
return ConstantVector::get(Vals);
1223312236
}
@@ -17610,6 +17613,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
1761017613
ConstantInt::getFalse(VecTy->getContext()));
1761117614
for (int I : CompressMask)
1761217615
MaskValues[I] = ConstantInt::getTrue(VecTy->getContext());
17616+
if (auto *VecTy = dyn_cast<FixedVectorType>(LI->getType())) {
17617+
assert(SLPReVec && "Only supported by REVEC.");
17618+
MaskValues = replicateMask(MaskValues, VecTy->getNumElements());
17619+
}
1761317620
Constant *MaskValue = ConstantVector::get(MaskValues);
1761417621
NewLI = Builder.CreateMaskedLoad(LoadVecTy, PO, CommonAlignment,
1761517622
MaskValue);
@@ -17618,6 +17625,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
1761817625
}
1761917626
NewLI = ::propagateMetadata(NewLI, E->Scalars);
1762017627
// TODO: include this cost into CommonCost.
17628+
if (auto *VecTy = dyn_cast<FixedVectorType>(LI->getType())) {
17629+
assert(SLPReVec && "FixedVectorType is not expected.");
17630+
transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
17631+
CompressMask);
17632+
}
1762117633
NewLI =
1762217634
cast<Instruction>(Builder.CreateShuffleVector(NewLI, CompressMask));
1762317635
} else if (E->State == TreeEntry::StridedVectorize) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx10.1-512 -passes=slp-vectorizer -S -slp-revec < %s | FileCheck %s
3+
4+
define void @test(ptr %in) {
5+
; CHECK-LABEL: @test(
6+
; CHECK-NEXT: entry:
7+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[IN:%.*]], i64 32
8+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[IN]], i64 64
9+
; CHECK-NEXT: [[TMP2:%.*]] = call <32 x i16> @llvm.masked.load.v32i16.p0(ptr [[TMP1]], i32 2, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <32 x i16> poison)
10+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <32 x i16> [[TMP2]], <32 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
11+
; CHECK-NEXT: [[TMP4:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v8i32(<16 x i32> poison, <8 x i32> zeroinitializer, i64 0)
12+
; CHECK-NEXT: [[TMP5:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v8i32(<16 x i32> [[TMP4]], <8 x i32> zeroinitializer, i64 8)
13+
; CHECK-NEXT: [[TMP6:%.*]] = trunc <16 x i32> [[TMP5]] to <16 x i16>
14+
; CHECK-NEXT: [[TMP7:%.*]] = or <16 x i16> [[TMP6]], [[TMP3]]
15+
; CHECK-NEXT: store <16 x i16> [[TMP7]], ptr [[TMP0]], align 2
16+
; CHECK-NEXT: ret void
17+
;
18+
entry:
19+
%0 = getelementptr i8, ptr %in, i64 112
20+
%wide.load = load <8 x i16>, ptr %0, align 2
21+
%1 = sext <8 x i16> %wide.load to <8 x i32>
22+
%2 = getelementptr i8, ptr %in, i64 48
23+
%3 = or <8 x i32> zeroinitializer, %1
24+
%4 = getelementptr i8, ptr %in, i64 32
25+
%5 = getelementptr i8, ptr %in, i64 64
26+
%wide.load155 = load <8 x i16>, ptr %5, align 2
27+
%6 = sext <8 x i16> %wide.load155 to <8 x i32>
28+
%7 = or <8 x i32> zeroinitializer, %6
29+
%8 = trunc <8 x i32> %3 to <8 x i16>
30+
store <8 x i16> %8, ptr %2, align 2
31+
%9 = trunc <8 x i32> %7 to <8 x i16>
32+
store <8 x i16> %9, ptr %4, align 2
33+
ret void
34+
}

0 commit comments

Comments
 (0)