Skip to content

Commit 009002a

Browse files
committed
[SLP][NFC]Unify matching for perfect diamond match between cost and codegen
models, NFC.
1 parent 70900ec commit 009002a

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7659,16 +7659,24 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
76597659
<< "SLP: perfect diamond match for gather bundle "
76607660
<< shortBundleName(VL) << ".\n");
76617661
// Restore the mask for previous partially matched values.
7662-
for (auto [I, V] : enumerate(E->Scalars)) {
7663-
if (isa<PoisonValue>(V)) {
7664-
Mask[I] = PoisonMaskElem;
7665-
continue;
7662+
Mask.resize(E->Scalars.size());
7663+
const TreeEntry *FrontTE = Entries.front().front();
7664+
if (FrontTE->ReorderIndices.empty() &&
7665+
((FrontTE->ReuseShuffleIndices.empty() &&
7666+
E->Scalars.size() == FrontTE->Scalars.size()) ||
7667+
(E->Scalars.size() == FrontTE->ReuseShuffleIndices.size()))) {
7668+
std::iota(Mask.begin(), Mask.end(), 0);
7669+
} else {
7670+
for (auto [I, V] : enumerate(E->Scalars)) {
7671+
if (isa<PoisonValue>(V)) {
7672+
Mask[I] = PoisonMaskElem;
7673+
continue;
7674+
}
7675+
Mask[I] = FrontTE->findLaneForValue(V);
76667676
}
7667-
if (Mask[I] == PoisonMaskElem)
7668-
Mask[I] = Entries.front().front()->findLaneForValue(V);
76697677
}
7670-
Estimator.add(*Entries.front().front(), Mask);
7671-
return Estimator.finalize(E->ReuseShuffleIndices);
7678+
Estimator.add(*FrontTE, Mask);
7679+
return Estimator.finalize(E->getCommonMask());
76727680
}
76737681
if (!Resized) {
76747682
if (GatheredScalars.size() != VF &&
@@ -9460,10 +9468,19 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
94609468
auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) {
94619469
return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars);
94629470
});
9463-
if (It != FirstEntries.end() && (*It)->getVectorFactor() == VL.size()) {
9471+
if (It != FirstEntries.end() &&
9472+
((*It)->getVectorFactor() == VL.size() ||
9473+
((*It)->getVectorFactor() == TE->Scalars.size() &&
9474+
TE->ReuseShuffleIndices.size() == VL.size() &&
9475+
(*It)->isSame(TE->Scalars)))) {
94649476
Entries.push_back(*It);
9465-
std::iota(std::next(Mask.begin(), Part * VL.size()),
9466-
std::next(Mask.begin(), (Part + 1) * VL.size()), 0);
9477+
if ((*It)->getVectorFactor() == VL.size()) {
9478+
std::iota(std::next(Mask.begin(), Part * VL.size()),
9479+
std::next(Mask.begin(), (Part + 1) * VL.size()), 0);
9480+
} else {
9481+
SmallVector<int> CommonMask = TE->getCommonMask();
9482+
copy(CommonMask, Mask.begin());
9483+
}
94679484
// Clear undef scalars.
94689485
for (int I = 0, Sz = VL.size(); I < Sz; ++I)
94699486
if (isa<PoisonValue>(VL[I]))
@@ -10657,6 +10674,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
1065710674
<< "SLP: perfect diamond match for gather bundle "
1065810675
<< shortBundleName(E->Scalars) << ".\n");
1065910676
// Restore the mask for previous partially matched values.
10677+
Mask.resize(E->Scalars.size());
1066010678
const TreeEntry *FrontTE = Entries.front().front();
1066110679
if (FrontTE->ReorderIndices.empty() &&
1066210680
((FrontTE->ReuseShuffleIndices.empty() &&

0 commit comments

Comments
 (0)