Skip to content

Commit 1c722fc

Browse files
authored
[RISCV][TTI] Use processShuffleMask for shuffle legalization estimate (#136191)
We had some code which tried to estimate legalization costs for illegally typed shuffles, but it only handled the case of a widening shuffle, and used a somewhat adhoc heuristic. We can reuse the processShuffleMask utility (which we already use for individual vector register splitting when exact VLEN is known) to perform the same splitting given the legal vector type as the unit of split instead. This makes the costing both simpler and more robust. Note that this swings costs for illegal shuffles pretty wildly as we were previously sometimes hitting the adhoc code, and sometimes falling through into generic scalarization costing. I don't know that any of the costs for the individual tests in tree are significant, but the test which which triggered me finding this was reported to me by Alexey reduced from something triggering a bad choice in SLP for x264. So this has the potential to be somewhat high impact.
1 parent 55678dc commit 1c722fc

File tree

4 files changed

+233
-377
lines changed

4 files changed

+233
-377
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,62 @@ static VectorType *getVRGatherIndexType(MVT DataVT, const RISCVSubtarget &ST,
385385
return cast<VectorType>(EVT(IndexVT).getTypeForEVT(C));
386386
}
387387

388+
/// Attempt to approximate the cost of a shuffle which will require splitting
389+
/// during legalization. Note that processShuffleMasks is not an exact proxy
390+
/// for the algorithm used in LegalizeVectorTypes, but hopefully it's a
391+
/// reasonably close upperbound.
392+
static InstructionCost costShuffleViaSplitting(RISCVTTIImpl &TTI, MVT LegalVT,
393+
VectorType *Tp,
394+
ArrayRef<int> Mask,
395+
TTI::TargetCostKind CostKind) {
396+
assert(LegalVT.isFixedLengthVector() && !Mask.empty() &&
397+
"Expected fixed vector type and non-empty mask");
398+
unsigned LegalNumElts = LegalVT.getVectorNumElements();
399+
// Number of destination vectors after legalization:
400+
unsigned NumOfDests = divideCeil(Mask.size(), LegalNumElts);
401+
// We are going to permute multiple sources and the result will be in
402+
// multiple destinations. Providing an accurate cost only for splits where
403+
// the element type remains the same.
404+
if (NumOfDests <= 1 ||
405+
LegalVT.getVectorElementType().getSizeInBits() !=
406+
Tp->getElementType()->getPrimitiveSizeInBits() ||
407+
LegalNumElts >= Tp->getElementCount().getFixedValue())
408+
return InstructionCost::getInvalid();
409+
410+
unsigned VecTySize = TTI.getDataLayout().getTypeStoreSize(Tp);
411+
unsigned LegalVTSize = LegalVT.getStoreSize();
412+
// Number of source vectors after legalization:
413+
unsigned NumOfSrcs = divideCeil(VecTySize, LegalVTSize);
414+
415+
auto *SingleOpTy = FixedVectorType::get(Tp->getElementType(), LegalNumElts);
416+
417+
unsigned NormalizedVF = LegalNumElts * std::max(NumOfSrcs, NumOfDests);
418+
unsigned NumOfSrcRegs = NormalizedVF / LegalNumElts;
419+
unsigned NumOfDestRegs = NormalizedVF / LegalNumElts;
420+
SmallVector<int> NormalizedMask(NormalizedVF, PoisonMaskElem);
421+
assert(NormalizedVF >= Mask.size() &&
422+
"Normalized mask expected to be not shorter than original mask.");
423+
copy(Mask, NormalizedMask.begin());
424+
InstructionCost Cost = 0;
425+
SmallDenseSet<std::pair<ArrayRef<int>, unsigned>> ReusedSingleSrcShuffles;
426+
processShuffleMasks(
427+
NormalizedMask, NumOfSrcRegs, NumOfDestRegs, NumOfDestRegs, []() {},
428+
[&](ArrayRef<int> RegMask, unsigned SrcReg, unsigned DestReg) {
429+
if (ShuffleVectorInst::isIdentityMask(RegMask, RegMask.size()))
430+
return;
431+
if (!ReusedSingleSrcShuffles.insert(std::make_pair(RegMask, SrcReg))
432+
.second)
433+
return;
434+
Cost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, SingleOpTy,
435+
RegMask, CostKind, 0, nullptr);
436+
},
437+
[&](ArrayRef<int> RegMask, unsigned Idx1, unsigned Idx2, bool NewReg) {
438+
Cost += TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, SingleOpTy, RegMask,
439+
CostKind, 0, nullptr);
440+
});
441+
return Cost;
442+
}
443+
388444
/// Try to perform better estimation of the permutation.
389445
/// 1. Split the source/destination vectors into real registers.
390446
/// 2. Do the mask analysis to identify which real registers are
@@ -647,48 +703,13 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
647703
return true;
648704
}
649705
};
650-
// We are going to permute multiple sources and the result will be in
651-
// multiple destinations. Providing an accurate cost only for splits where
652-
// the element type remains the same.
706+
653707
if (!Mask.empty() && LT.first.isValid() && LT.first != 1 &&
654-
shouldSplit(Kind) &&
655-
LT.second.getVectorElementType().getSizeInBits() ==
656-
Tp->getElementType()->getPrimitiveSizeInBits() &&
657-
LT.second.getVectorNumElements() <
658-
cast<FixedVectorType>(Tp)->getNumElements() &&
659-
divideCeil(Mask.size(),
660-
cast<FixedVectorType>(Tp)->getNumElements()) ==
661-
static_cast<unsigned>(*LT.first.getValue())) {
662-
unsigned NumRegs = *LT.first.getValue();
663-
unsigned VF = cast<FixedVectorType>(Tp)->getNumElements();
664-
unsigned SubVF = PowerOf2Ceil(VF / NumRegs);
665-
auto *SubVecTy = FixedVectorType::get(Tp->getElementType(), SubVF);
666-
667-
InstructionCost Cost = 0;
668-
for (unsigned I = 0, NumSrcRegs = divideCeil(Mask.size(), SubVF);
669-
I < NumSrcRegs; ++I) {
670-
bool IsSingleVector = true;
671-
SmallVector<int> SubMask(SubVF, PoisonMaskElem);
672-
transform(
673-
Mask.slice(I * SubVF,
674-
I == NumSrcRegs - 1 ? Mask.size() % SubVF : SubVF),
675-
SubMask.begin(), [&](int I) -> int {
676-
if (I == PoisonMaskElem)
677-
return PoisonMaskElem;
678-
bool SingleSubVector = I / VF == 0;
679-
IsSingleVector &= SingleSubVector;
680-
return (SingleSubVector ? 0 : 1) * SubVF + (I % VF) % SubVF;
681-
});
682-
if (all_of(enumerate(SubMask), [](auto &&P) {
683-
return P.value() == PoisonMaskElem ||
684-
static_cast<unsigned>(P.value()) == P.index();
685-
}))
686-
continue;
687-
Cost += getShuffleCost(IsSingleVector ? TTI::SK_PermuteSingleSrc
688-
: TTI::SK_PermuteTwoSrc,
689-
SubVecTy, SubMask, CostKind, 0, nullptr);
690-
}
691-
return Cost;
708+
shouldSplit(Kind)) {
709+
InstructionCost SplitCost =
710+
costShuffleViaSplitting(*this, LT.second, FVTp, Mask, CostKind);
711+
if (SplitCost.isValid())
712+
return SplitCost;
692713
}
693714
}
694715

0 commit comments

Comments
 (0)