-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SLP]Allow matching and shuffling of extractelement vector operands with different VF. #97414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Created using spr 1.3.5
@llvm/pr-subscribers-llvm-transforms Author: Alexey Bataev (alexey-bataev) ChangesAllows better codegen with the free resizing of small VF vector operands Full diff: https://github.com/llvm/llvm-project/pull/97414.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e233de89a33f1..3ad40c329211e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -510,11 +510,17 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
const auto *It = find_if(VL, IsaPred<ExtractElementInst>);
if (It == VL.end())
return std::nullopt;
- auto *EI0 = cast<ExtractElementInst>(*It);
- if (isa<ScalableVectorType>(EI0->getVectorOperandType()))
- return std::nullopt;
unsigned Size =
- cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements();
+ std::accumulate(VL.begin(), VL.end(), 0u, [](unsigned S, Value *V) {
+ auto *EI = dyn_cast<ExtractElementInst>(V);
+ if (!EI)
+ return S;
+ auto *VTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
+ if (!VTy)
+ return S;
+ return std::max(S, VTy->getNumElements());
+ });
+
Value *Vec1 = nullptr;
Value *Vec2 = nullptr;
bool HasNonUndefVec = any_of(VL, [](Value *V) {
@@ -544,8 +550,6 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
if (isa<UndefValue>(Vec)) {
Mask[I] = I;
} else {
- if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Size)
- return std::nullopt;
if (isa<UndefValue>(EI->getIndexOperand()))
continue;
auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand());
@@ -10639,36 +10643,20 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
VectorOpToIdx[EI->getVectorOperand()].push_back(I);
}
// Sort the vector operands by the maximum number of uses in extractelements.
- MapVector<unsigned, SmallVector<Value *>> VFToVector;
- for (const auto &Data : VectorOpToIdx)
- VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
- .push_back(Data.first);
- for (auto &Data : VFToVector) {
- stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
- return VectorOpToIdx.find(V1)->second.size() >
- VectorOpToIdx.find(V2)->second.size();
- });
- }
- // Find the best pair of the vectors with the same number of elements or a
- // single vector.
+ SmallVector<std::pair<Value *, SmallVector<int>>> Vectors =
+ VectorOpToIdx.takeVector();
+ stable_sort(Vectors, [](const auto &P1, const auto &P2) {
+ return P1.second.size() > P2.second.size();
+ });
+ // Find the best pair of the vectors or a single vector.
const int UndefSz = UndefVectorExtracts.size();
unsigned SingleMax = 0;
- Value *SingleVec = nullptr;
unsigned PairMax = 0;
- std::pair<Value *, Value *> PairVec(nullptr, nullptr);
- for (auto &Data : VFToVector) {
- Value *V1 = Data.second.front();
- if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
- SingleMax = VectorOpToIdx[V1].size() + UndefSz;
- SingleVec = V1;
- }
- Value *V2 = nullptr;
- if (Data.second.size() > 1)
- V2 = *std::next(Data.second.begin());
- if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
- UndefSz) {
- PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
- PairVec = std::make_pair(V1, V2);
+ if (!Vectors.empty()) {
+ SingleMax = Vectors.front().second.size() + UndefSz;
+ if (Vectors.size() > 1) {
+ auto *ItNext = std::next(Vectors.begin());
+ PairMax = SingleMax + ItNext->second.size();
}
}
if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
@@ -10679,11 +10667,11 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
SmallVector<Value *> GatheredExtracts(
VL.size(), PoisonValue::get(VL.front()->getType()));
if (SingleMax >= PairMax && SingleMax) {
- for (int Idx : VectorOpToIdx[SingleVec])
+ for (int Idx : Vectors.front().second)
std::swap(GatheredExtracts[Idx], VL[Idx]);
- } else {
- for (Value *V : {PairVec.first, PairVec.second})
- for (int Idx : VectorOpToIdx[V])
+ } else if (!Vectors.empty()) {
+ for (unsigned Idx : {0, 1})
+ for (int Idx : Vectors[Idx].second)
std::swap(GatheredExtracts[Idx], VL[Idx]);
}
// Add extracts from undefs too.
@@ -11752,25 +11740,29 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, Limit);
constexpr int MaxBases = 2;
SmallVector<Value *, MaxBases> Bases(MaxBases);
-#ifndef NDEBUG
- int PrevSize = 0;
-#endif // NDEBUG
- for (const auto [I, V]: enumerate(VL)) {
- if (SubMask[I] == PoisonMaskElem)
+ auto VLMask = zip(VL, SubMask);
+ const unsigned VF = std::accumulate(
+ VLMask.begin(), VLMask.end(), 0U, [&](unsigned S, const auto &D) {
+ if (std::get<1>(D) == PoisonMaskElem)
+ return S;
+ Value *VecOp =
+ cast<ExtractElementInst>(std::get<0>(D))->getVectorOperand();
+ if (const TreeEntry *TE = R.getTreeEntry(VecOp))
+ VecOp = TE->VectorizedValue;
+ assert(VecOp && "Expected vectorized value.");
+ const unsigned Size =
+ cast<FixedVectorType>(VecOp->getType())->getNumElements();
+ return std::max(S, Size);
+ });
+ for (const auto [V, I] : VLMask) {
+ if (I == PoisonMaskElem)
continue;
Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand();
if (const TreeEntry *TE = R.getTreeEntry(VecOp))
VecOp = TE->VectorizedValue;
assert(VecOp && "Expected vectorized value.");
- const int Size =
- cast<FixedVectorType>(VecOp->getType())->getNumElements();
-#ifndef NDEBUG
- assert((PrevSize == Size || PrevSize == 0) &&
- "Expected vectors of the same size.");
- PrevSize = Size;
-#endif // NDEBUG
VecOp = castToScalarTyElem(VecOp);
- Bases[SubMask[I] < Size ? 0 : 1] = VecOp;
+ Bases[I / VF] = VecOp;
}
if (!Bases.front())
continue;
@@ -11796,16 +11788,17 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
"Expected first part or all previous parts masked.");
copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
} else {
- unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements();
+ unsigned NewVF =
+ cast<FixedVectorType>(Vec->getType())->getNumElements();
if (Vec->getType() != SubVec->getType()) {
unsigned SubVecVF =
cast<FixedVectorType>(SubVec->getType())->getNumElements();
- VF = std::max(VF, SubVecVF);
+ NewVF = std::max(NewVF, SubVecVF);
}
// Adjust SubMask.
for (int &Idx : SubMask)
if (Idx != PoisonMaskElem)
- Idx += VF;
+ Idx += NewVF;
copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
Vec = createShuffle(Vec, SubVec, VecMask);
TransformToIdentity(VecMask);
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll b/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
index 24b95c4e6ff2f..4e6ed4bce6588 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
@@ -9,11 +9,10 @@ define void @foo(double %i) {
; CHECK-NEXT: [[TMP1:%.*]] = fsub <4 x double> zeroinitializer, [[TMP0]]
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x double> poison, double [[I]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = fsub <2 x double> zeroinitializer, [[TMP3]]
-; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[TMP4]], i32 1
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <8 x i32> <i32 0, i32 poison, i32 poison, i32 1, i32 poison, i32 0, i32 poison, i32 1>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <8 x double> [[TMP9]], <8 x double> <double poison, double 0.000000e+00, double poison, double poison, double 0.000000e+00, double poison, double poison, double poison>, <8 x i32> <i32 0, i32 9, i32 poison, i32 3, i32 12, i32 5, i32 poison, i32 7>
-; CHECK-NEXT: [[TMP7:%.*]] = insertelement <8 x double> [[TMP6]], double [[TMP5]], i32 2
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x double> [[TMP7]], <8 x double> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP22]], <4 x i32> <i32 0, i32 0, i32 5, i32 1>
+; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x double> [[TMP5]], <4 x double> [[TMP5]], <8 x i32> <i32 0, i32 poison, i32 2, i32 3, i32 poison, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <8 x double> [[TMP6]], <8 x double> <double poison, double 0.000000e+00, double poison, double poison, double 0.000000e+00, double poison, double poison, double poison>, <8 x i32> <i32 0, i32 9, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP12:%.*]] = fmul <8 x double> <double 0.000000e+00, double poison, double 0.000000e+00, double 0.000000e+00, double poison, double 0.000000e+00, double 0.000000e+00, double 0.000000e+00>, [[TMP8]]
; CHECK-NEXT: [[TMP13:%.*]] = fadd <8 x double> zeroinitializer, [[TMP12]]
; CHECK-NEXT: [[TMP14:%.*]] = fadd <8 x double> [[TMP13]], zeroinitializer
@@ -27,7 +26,6 @@ define void @foo(double %i) {
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x double> [[TMP18]], i32 1
; CHECK-NEXT: [[I118:%.*]] = fadd double [[TMP19]], [[TMP20]]
; CHECK-NEXT: [[TMP21:%.*]] = fmul <4 x double> zeroinitializer, [[TMP1]]
-; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x double> <double 0.000000e+00, double 0.000000e+00, double 0.000000e+00, double poison>, <4 x double> [[TMP22]], <4 x i32> <i32 0, i32 1, i32 2, i32 5>
; CHECK-NEXT: [[TMP24:%.*]] = fadd <4 x double> [[TMP21]], [[TMP23]]
; CHECK-NEXT: [[TMP25:%.*]] = fadd <4 x double> [[TMP24]], zeroinitializer
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
index 57fa83b1ccdd6..d5f2cf7fc28c4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
@@ -6,7 +6,6 @@ define void @test() {
; CHECK-NEXT: bb:
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i32> zeroinitializer, i32 0
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i32> zeroinitializer, i32 0
-; CHECK-NEXT: [[TMP2:%.*]] = insertelement <2 x i32> <i32 0, i32 undef>, i32 [[TMP1]], i32 1
; CHECK-NEXT: [[ICMP:%.*]] = icmp ult i32 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret void
;
|
for (Value *V : {PairVec.first, PairVec.second}) | ||
for (int Idx : VectorOpToIdx[V]) | ||
} else if (!Vectors.empty()) { | ||
for (unsigned Idx : {0, 1}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why can we guarantee that Vectors.size() == 2 here (but not earlier)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we checked SingleMax != 0 in the first condition, which is a guarantee Vectors is not empty. Here I can change the condition just to check that PairMax != 0 instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Allows better codegen with the free resizing of small VF vector operands
and then regular shuffling of the operands of the same size and
simplifies the code.