@@ -12613,11 +12613,13 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
12613
12613
}
12614
12614
InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
12615
12615
/// Finalize emission of the shuffles.
12616
- InstructionCost
12617
- finalize(ArrayRef<int> ExtMask,
12618
- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619
- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620
- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
12616
+ InstructionCost finalize(
12617
+ ArrayRef<int> ExtMask,
12618
+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619
+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620
+ function_ref<void(Value *&, SmallVectorImpl<int> &,
12621
+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
12622
+ Action = {}) {
12621
12623
IsFinalized = true;
12622
12624
if (Action) {
12623
12625
const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
@@ -12629,7 +12631,10 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
12629
12631
assert(VF > 0 &&
12630
12632
"Expected vector length for the final value before action.");
12631
12633
Value *V = cast<Value *>(Vec);
12632
- Action(V, CommonMask);
12634
+ Action(V, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
12635
+ Cost += createShuffle(V1, V2, Mask);
12636
+ return V1;
12637
+ });
12633
12638
InVectors.front() = V;
12634
12639
}
12635
12640
if (!SubVectors.empty()) {
@@ -16593,11 +16598,13 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
16593
16598
/// Finalize emission of the shuffles.
16594
16599
/// \param Action the action (if any) to be performed before final applying of
16595
16600
/// the \p ExtMask mask.
16596
- Value *
16597
- finalize(ArrayRef<int> ExtMask,
16598
- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16599
- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16600
- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
16601
+ Value *finalize(
16602
+ ArrayRef<int> ExtMask,
16603
+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16604
+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16605
+ function_ref<void(Value *&, SmallVectorImpl<int> &,
16606
+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
16607
+ Action = {}) {
16601
16608
IsFinalized = true;
16602
16609
if (Action) {
16603
16610
Value *Vec = InVectors.front();
@@ -16616,7 +16623,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
16616
16623
std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0);
16617
16624
Vec = createShuffle(Vec, nullptr, ResizeMask);
16618
16625
}
16619
- Action(Vec, CommonMask);
16626
+ Action(Vec, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
16627
+ return createShuffle(V1, V2, Mask);
16628
+ });
16620
16629
InVectors.front() = Vec;
16621
16630
}
16622
16631
if (!SubVectors.empty()) {
@@ -17278,9 +17287,67 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
17278
17287
else
17279
17288
Res = ShuffleBuilder.finalize(
17280
17289
E->ReuseShuffleIndices, SubVectors, SubVectorsMask, E->Scalars.size(),
17281
- [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
17282
- TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
17283
- Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
17290
+ [&](Value *&Vec, SmallVectorImpl<int> &Mask, auto CreateShuffle) {
17291
+ bool IsSplat = isSplat(NonConstants);
17292
+ SmallVector<int> BVMask(Mask.size(), PoisonMaskElem);
17293
+ TryPackScalars(NonConstants, BVMask, /*IsRootPoison=*/false);
17294
+ auto CheckIfSplatIsProfitable = [&]() {
17295
+ // Estimate the cost of splatting + shuffle and compare with
17296
+ // insert + shuffle.
17297
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
17298
+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>);
17299
+ if (isa<ExtractElementInst>(V) || isVectorized(V))
17300
+ return false;
17301
+ InstructionCost SplatCost = TTI->getVectorInstrCost(
17302
+ Instruction::InsertElement, VecTy, CostKind, /*Index=*/0,
17303
+ PoisonValue::get(VecTy), V);
17304
+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17305
+ for (auto [Idx, I] : enumerate(BVMask))
17306
+ if (I != PoisonMaskElem)
17307
+ NewMask[Idx] = Mask.size();
17308
+ SplatCost += ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, VecTy,
17309
+ NewMask, CostKind);
17310
+ InstructionCost BVCost = TTI->getVectorInstrCost(
17311
+ Instruction::InsertElement, VecTy, CostKind,
17312
+ *find_if(Mask, [](int I) { return I != PoisonMaskElem; }),
17313
+ Vec, V);
17314
+ // Shuffle required?
17315
+ if (count(BVMask, PoisonMaskElem) <
17316
+ static_cast<int>(BVMask.size() - 1)) {
17317
+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17318
+ for (auto [Idx, I] : enumerate(BVMask))
17319
+ if (I != PoisonMaskElem)
17320
+ NewMask[Idx] = I;
17321
+ BVCost += ::getShuffleCost(*TTI, TTI::SK_PermuteSingleSrc,
17322
+ VecTy, NewMask, CostKind);
17323
+ }
17324
+ return SplatCost <= BVCost;
17325
+ };
17326
+ if (!IsSplat || Mask.size() <= 2 || !CheckIfSplatIsProfitable()) {
17327
+ for (auto [Idx, I] : enumerate(BVMask))
17328
+ if (I != PoisonMaskElem)
17329
+ Mask[Idx] = I;
17330
+ Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
17331
+ } else {
17332
+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>);
17333
+ SmallVector<Value *> Values(NonConstants.size(),
17334
+ PoisonValue::get(ScalarTy));
17335
+ Values[0] = V;
17336
+ Value *BV = ShuffleBuilder.gather(Values, BVMask.size());
17337
+ SmallVector<int> SplatMask(BVMask.size(), PoisonMaskElem);
17338
+ transform(BVMask, SplatMask.begin(), [](int I) {
17339
+ return I == PoisonMaskElem ? PoisonMaskElem : 0;
17340
+ });
17341
+ if (!ShuffleVectorInst::isIdentityMask(SplatMask, VF))
17342
+ BV = CreateShuffle(BV, nullptr, SplatMask);
17343
+ for (auto [Idx, I] : enumerate(BVMask))
17344
+ if (I != PoisonMaskElem)
17345
+ Mask[Idx] = BVMask.size() + Idx;
17346
+ Vec = CreateShuffle(Vec, BV, Mask);
17347
+ for (auto [Idx, I] : enumerate(Mask))
17348
+ if (I != PoisonMaskElem)
17349
+ Mask[Idx] = Idx;
17350
+ }
17284
17351
});
17285
17352
} else if (!allConstant(GatheredScalars)) {
17286
17353
// Gather unique scalars and all constants.
0 commit comments