Commit ce8ec31

[SLP][REVEC] Support more mask pattern usage in shufflevector. (llvm#106212)
1 parent b74e09c commit ce8ec31

2 files changed: +76 additions, -27 deletions

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 50 additions & 16 deletions
@@ -293,8 +293,7 @@ static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
 /// A group has the following features
 /// 1. All of value in a group are shufflevector.
 /// 2. The mask of all shufflevector is isExtractSubvectorMask.
-/// 3. The mask of all shufflevector uses all of the elements of the source (and
-///    the elements are used in order).
+/// 3. The mask of all shufflevector uses all of the elements of the source.
 /// e.g., it is 1 group (%0)
 ///       %1 = shufflevector <16 x i8> %0, <16 x i8> poison,
 ///            <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -322,15 +321,16 @@ static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
   auto *SV = cast<ShuffleVectorInst>(VL.front());
   unsigned SVNumElements =
       cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
-  unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+  unsigned ShuffleMaskSize = SV->getShuffleMask().size();
+  unsigned GroupSize = SVNumElements / ShuffleMaskSize;
   if (GroupSize == 0 || (VL.size() % GroupSize) != 0)
     return 0;
   unsigned NumGroup = 0;
   for (size_t I = 0, E = VL.size(); I != E; I += GroupSize) {
     auto *SV = cast<ShuffleVectorInst>(VL[I]);
     Value *Src = SV->getOperand(0);
     ArrayRef<Value *> Group = VL.slice(I, GroupSize);
-    SmallVector<int> ExtractionIndex(SVNumElements);
+    SmallBitVector ExpectedIndex(GroupSize);
     if (!all_of(Group, [&](Value *V) {
           auto *SV = cast<ShuffleVectorInst>(V);
           // From the same source.
@@ -339,12 +339,11 @@ static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
           int Index;
           if (!SV->isExtractSubvectorMask(Index))
             return false;
-          for (int I : seq<int>(Index, Index + SV->getShuffleMask().size()))
-            ExtractionIndex.push_back(I);
+          ExpectedIndex.set(Index / ShuffleMaskSize);
           return true;
         }))
       return 0;
-    if (!is_sorted(ExtractionIndex))
+    if (!ExpectedIndex.all())
       return 0;
     ++NumGroup;
   }
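
Read in isolation, the relaxed check in getShufflevectorNumGroups amounts to: record which ShuffleMaskSize-wide chunk of the source each extract-subvector shuffle starts at, and require that every chunk is covered, in any order. Below is a minimal standalone sketch of that idea, using std::vector<bool> and plain offsets in place of llvm::SmallBitVector and the IR values; it is an illustration, not the LLVM code itself.

#include <cstddef>
#include <vector>

// A "group" of extract-subvector shuffles covers its source if every
// ShuffleMaskSize-wide chunk of the source is extracted by exactly one
// member, in any order. The previous check additionally required the
// chunks to be extracted in ascending order.
bool groupCoversSource(const std::vector<int> &ExtractOffsets,
                       int ShuffleMaskSize, std::size_t GroupSize) {
  std::vector<bool> ExpectedIndex(GroupSize, false); // SmallBitVector stand-in
  for (int Offset : ExtractOffsets) {
    std::size_t Chunk = static_cast<std::size_t>(Offset / ShuffleMaskSize);
    if (Chunk >= GroupSize)
      return false; // extract reaches past the source; not a valid group
    ExpectedIndex[Chunk] = true;
  }
  // Equivalent of ExpectedIndex.all(): every chunk of the source is used.
  for (bool Seen : ExpectedIndex)
    if (!Seen)
      return false;
  return true;
}

// Example: four <4 x i32> extracts from one <16 x i32> source taken in
// reverse order now count as one group.
//   groupCoversSource({12, 8, 4, 0}, /*ShuffleMaskSize=*/4, /*GroupSize=*/4)
// returns true; the old ordered check (is_sorted over all extracted element
// indices) rejected this pattern.
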
@@ -10289,12 +10288,40 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       return VecCost;
     };
     if (SLPReVec && !E->isAltShuffle())
-      return GetCostDiff(GetScalarCost, [](InstructionCost) {
-        // shufflevector will be eliminated by instcombine because the
-        // shufflevector masks are used in order (guaranteed by
-        // getShufflevectorNumGroups). The vector cost is 0.
-        return TTI::TCC_Free;
-      });
+      return GetCostDiff(
+          GetScalarCost, [&](InstructionCost) -> InstructionCost {
+            // If a group uses mask in order, the shufflevector can be
+            // eliminated by instcombine. Then the cost is 0.
+            assert(isa<ShuffleVectorInst>(VL.front()) &&
+                   "Not supported shufflevector usage.");
+            auto *SV = cast<ShuffleVectorInst>(VL.front());
+            unsigned SVNumElements =
+                cast<FixedVectorType>(SV->getOperand(0)->getType())
+                    ->getNumElements();
+            unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+            for (size_t I = 0, End = VL.size(); I != End; I += GroupSize) {
+              ArrayRef<Value *> Group = VL.slice(I, GroupSize);
+              int NextIndex = 0;
+              if (!all_of(Group, [&](Value *V) {
+                    assert(isa<ShuffleVectorInst>(V) &&
+                           "Not supported shufflevector usage.");
+                    auto *SV = cast<ShuffleVectorInst>(V);
+                    int Index;
+                    bool isExtractSubvectorMask =
+                        SV->isExtractSubvectorMask(Index);
+                    assert(isExtractSubvectorMask &&
+                           "Not supported shufflevector usage.");
+                    if (NextIndex != Index)
+                      return false;
+                    NextIndex += SV->getShuffleMask().size();
+                    return true;
+                  }))
+                return ::getShuffleCost(
+                    *TTI, TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
+                    calculateShufflevectorMask(E->Scalars));
+            }
+            return TTI::TCC_Free;
+          });
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
   case Instruction::Freeze:
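
The cost callback now distinguishes two cases instead of always returning TTI::TCC_Free: a group whose extracts walk the source in order is still modeled as free, because instcombine folds those shuffles away, while any out-of-order group is charged as one single-source permute. The decision reduces to the following standalone sketch, with a plain integer standing in for InstructionCost and a PermuteSingleSrcCost parameter standing in for the ::getShuffleCost(TTI, SK_PermuteSingleSrc, ...) query; both are simplifications, not the LLVM API.

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for InstructionCost and the TTI permute query.
using Cost = int;
constexpr Cost TCC_Free = 0;

// ExtractOffsets holds one starting offset per shufflevector in VL, and every
// GroupSize consecutive entries are assumed to read the same source (already
// validated by the grouping check above).
Cost reVecShuffleCost(const std::vector<int> &ExtractOffsets,
                      int ShuffleMaskSize, std::size_t GroupSize,
                      Cost PermuteSingleSrcCost) {
  for (std::size_t I = 0, End = ExtractOffsets.size(); I != End;
       I += GroupSize) {
    int NextIndex = 0;
    for (std::size_t J = I; J != I + GroupSize; ++J) {
      if (ExtractOffsets[J] != NextIndex)
        // An out-of-order group: instcombine will not fold the shuffles away,
        // so model the whole value as one single-source permute.
        return PermuteSingleSrcCost;
      NextIndex += ShuffleMaskSize;
    }
  }
  // Every group uses its mask in order: the shuffles fold away, cost is 0.
  return TCC_Free;
}
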
@@ -14072,9 +14099,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
         return E->VectorizedValue;
       }
-      // The current shufflevector usage always duplicate the source.
-      V = Builder.CreateShuffleVector(Src,
-                                      calculateShufflevectorMask(E->Scalars));
+      assert(isa<ShuffleVectorInst>(Src) &&
+             "Not supported shufflevector usage.");
+      auto *SVSrc = cast<ShuffleVectorInst>(Src);
+      assert(isa<PoisonValue>(SVSrc->getOperand(1)) &&
+             "Not supported shufflevector usage.");
+      SmallVector<int> ThisMask(calculateShufflevectorMask(E->Scalars));
+      SmallVector<int> NewMask(ThisMask.size());
+      transform(ThisMask, NewMask.begin(),
+                [&SVSrc](int Mask) { return SVSrc->getShuffleMask()[Mask]; });
+      V = Builder.CreateShuffleVector(SVSrc->getOperand(0), NewMask);
       propagateIRFlags(V, E->Scalars, VL0);
     } else {
       assert(E->isAltShuffle() &&
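
In vectorizeTree, instead of emitting a shufflevector of the already-shuffled Src, the new code looks through Src (asserted to be a shufflevector with a poison second operand) and folds the two masks into one, so a single shufflevector of Src's first operand is created. The composition itself is plain index chasing; here is a minimal sketch with std::vector<int> in place of SmallVector<int>, assuming all mask entries are non-negative (no undef lanes).

#include <cstddef>
#include <vector>

// If Src = shufflevector(A, poison, SrcMask) and the REVEC node wants
// shufflevector(Src, poison, ThisMask), then shuffling A directly with
// NewMask[I] = SrcMask[ThisMask[I]] yields the same elements with a single
// shufflevector.
std::vector<int> composeShuffleMasks(const std::vector<int> &SrcMask,
                                     const std::vector<int> &ThisMask) {
  std::vector<int> NewMask(ThisMask.size());
  for (std::size_t I = 0, E = ThisMask.size(); I != E; ++I)
    NewMask[I] = SrcMask[static_cast<std::size_t>(ThisMask[I])];
  return NewMask;
}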

llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll

Lines changed: 26 additions & 11 deletions
@@ -34,17 +34,9 @@ define void @test2(ptr %in, ptr %out) {
 ; CHECK-LABEL: @test2(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
-; CHECK-NEXT:    store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
-; CHECK-NEXT:    store <2 x i64> [[TMP6]], ptr [[TMP7]], align 8
-; CHECK-NEXT:    store <4 x i64> [[TMP4]], ptr [[TMP8]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i32> [[TMP0]] to <8 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i64> [[TMP1]], <8 x i64> poison, <8 x i32> <i32 2, i32 3, i32 0, i32 1, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    store <8 x i64> [[TMP2]], ptr [[OUT:%.*]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -67,3 +59,26 @@ entry:
   store <2 x i64> %8, ptr %12, align 8
   ret void
 }
+
+define void @test3(<16 x i32> %0, ptr %out) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <16 x i32> [[TMP0:%.*]], <16 x i32> poison, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    store <16 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    ret void
+;
+entry:
+  %1 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 12, i32 13, i32 14, i32 15>
+  %2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 8, i32 9, i32 10, i32 11>
+  %3 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %4 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %5 = getelementptr inbounds i32, ptr %out, i64 0
+  %6 = getelementptr inbounds i32, ptr %out, i64 4
+  %7 = getelementptr inbounds i32, ptr %out, i64 8
+  %8 = getelementptr inbounds i32, ptr %out, i64 12
+  store <4 x i32> %1, ptr %5, align 4
+  store <4 x i32> %2, ptr %6, align 4
+  store <4 x i32> %3, ptr %7, align 4
+  store <4 x i32> %4, ptr %8, align 4
+  ret void
+}
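
As a worked example of where the single 16-wide mask in @test3's CHECK line comes from: the four stored extracts start at offsets 12, 8, 4, and 0, and concatenating their 4-element masks in store order (which is effectively what the calculateShufflevectorMask step produces for this pattern; sketched below with plain std types as an illustration, not the LLVM code) yields the reversed-group permutation. Because the groups are out of order, instcombine cannot simply drop the shuffles, which is what the new cost path models by charging one SK_PermuteSingleSrc shuffle.

#include <cstdio>
#include <vector>

int main() {
  // One extract-subvector shuffle per store in @test3, in store order.
  const std::vector<int> GroupOffsets = {12, 8, 4, 0};
  const int ShuffleMaskSize = 4; // each extract produces a <4 x i32>

  // Concatenate the per-shuffle masks into one 16-element mask.
  std::vector<int> Mask;
  for (int Offset : GroupOffsets)
    for (int I = 0; I != ShuffleMaskSize; ++I)
      Mask.push_back(Offset + I);

  // Prints: 12 13 14 15 8 9 10 11 4 5 6 7 0 1 2 3
  for (int M : Mask)
    std::printf("%d ", M);
  std::printf("\n");
}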
