Skip to content

Commit 2f40145

Browse files
[RISCV][TTI]Use processShuffleMasks for cost estimations/actual per-register shuffles
Patch adds usage of processShuffleMasks in TTI for RISCV. This function is already used for X86 shuffles estimations and in DAGTypeLegalizer::SplitVecRes_VECTOR_SHUFFLE functions and in RISCV codegen. Patch allows better cost estimation for sparse masks and unifies cost/codegen between different targets/passes Reviewers: preames Reviewed By: preames Pull Request: llvm#118103
1 parent 87782b2 commit 2f40145

File tree

4 files changed

+183
-101
lines changed

4 files changed

+183
-101
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,104 @@ static VectorType *getVRGatherIndexType(MVT DataVT, const RISCVSubtarget &ST,
376376
return cast<VectorType>(EVT(IndexVT).getTypeForEVT(C));
377377
}
378378

379+
/// Try to perform better estimation of the permutation.
380+
/// 1. Split the source/destination vectors into real registers.
381+
/// 2. Do the mask analysis to identify which real registers are
382+
/// permuted. If more than 1 source registers are used for the
383+
/// destination register building, the cost for this destination register
384+
/// is (Number_of_source_register - 1) * Cost_PermuteTwoSrc. If only one
385+
/// source register is used, build mask and calculate the cost as a cost
386+
/// of PermuteSingleSrc.
387+
/// Also, for the single register permute we try to identify if the
388+
/// destination register is just a copy of the source register or the
389+
/// copy of the previous destination register (the cost is
390+
/// TTI::TCC_Basic). If the source register is just reused, the cost for
391+
/// this operation is 0.
392+
static InstructionCost
393+
costShuffleViaVRegSplitting(RISCVTTIImpl &TTI, MVT LegalVT,
394+
std::optional<unsigned> VLen, VectorType *Tp,
395+
ArrayRef<int> Mask, TTI::TargetCostKind CostKind) {
396+
InstructionCost NumOfDests = InstructionCost::getInvalid();
397+
if (VLen && LegalVT.isFixedLengthVector() && !Mask.empty()) {
398+
MVT ElemVT = LegalVT.getVectorElementType();
399+
unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
400+
LegalVT = TTI.getTypeLegalizationCost(
401+
FixedVectorType::get(Tp->getElementType(), ElemsPerVReg))
402+
.second;
403+
// Number of destination vectors after legalization:
404+
NumOfDests = divideCeil(Mask.size(), LegalVT.getVectorNumElements());
405+
}
406+
if (!NumOfDests.isValid() || NumOfDests <= 1 ||
407+
!LegalVT.isFixedLengthVector() ||
408+
LegalVT.getVectorElementType().getSizeInBits() !=
409+
Tp->getElementType()->getPrimitiveSizeInBits() ||
410+
LegalVT.getVectorNumElements() >= Tp->getElementCount().getFixedValue())
411+
return InstructionCost::getInvalid();
412+
413+
unsigned VecTySize = TTI.getDataLayout().getTypeStoreSize(Tp);
414+
unsigned LegalVTSize = LegalVT.getStoreSize();
415+
// Number of source vectors after legalization:
416+
unsigned NumOfSrcs = divideCeil(VecTySize, LegalVTSize);
417+
418+
auto *SingleOpTy = FixedVectorType::get(Tp->getElementType(),
419+
LegalVT.getVectorNumElements());
420+
421+
unsigned E = *NumOfDests.getValue();
422+
unsigned NormalizedVF =
423+
LegalVT.getVectorNumElements() * std::max(NumOfSrcs, E);
424+
unsigned NumOfSrcRegs = NormalizedVF / LegalVT.getVectorNumElements();
425+
unsigned NumOfDestRegs = NormalizedVF / LegalVT.getVectorNumElements();
426+
SmallVector<int> NormalizedMask(NormalizedVF, PoisonMaskElem);
427+
assert(NormalizedVF >= Mask.size() &&
428+
"Normalized mask expected to be not shorter than original mask.");
429+
copy(Mask, NormalizedMask.begin());
430+
InstructionCost Cost = 0;
431+
SmallBitVector ExtractedRegs(2 * NumOfSrcRegs);
432+
int NumShuffles = 0;
433+
processShuffleMasks(
434+
NormalizedMask, NumOfSrcRegs, NumOfDestRegs, NumOfDestRegs, []() {},
435+
[&](ArrayRef<int> RegMask, unsigned SrcReg, unsigned DestReg) {
436+
if (ExtractedRegs.test(SrcReg)) {
437+
Cost += TTI.getShuffleCost(TTI::SK_ExtractSubvector, Tp, {}, CostKind,
438+
(SrcReg % NumOfSrcRegs) *
439+
SingleOpTy->getNumElements(),
440+
SingleOpTy);
441+
ExtractedRegs.set(SrcReg);
442+
}
443+
if (!ShuffleVectorInst::isIdentityMask(RegMask, RegMask.size())) {
444+
++NumShuffles;
445+
Cost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, SingleOpTy,
446+
RegMask, CostKind, 0, nullptr);
447+
return;
448+
}
449+
},
450+
[&](ArrayRef<int> RegMask, unsigned Idx1, unsigned Idx2, bool NewReg) {
451+
if (ExtractedRegs.test(Idx1)) {
452+
Cost += TTI.getShuffleCost(
453+
TTI::SK_ExtractSubvector, Tp, {}, CostKind,
454+
(Idx1 % NumOfSrcRegs) * SingleOpTy->getNumElements(), SingleOpTy);
455+
ExtractedRegs.set(Idx1);
456+
}
457+
if (ExtractedRegs.test(Idx2)) {
458+
Cost += TTI.getShuffleCost(
459+
TTI::SK_ExtractSubvector, Tp, {}, CostKind,
460+
(Idx2 % NumOfSrcRegs) * SingleOpTy->getNumElements(), SingleOpTy);
461+
ExtractedRegs.set(Idx2);
462+
}
463+
Cost += TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, SingleOpTy, RegMask,
464+
CostKind, 0, nullptr);
465+
NumShuffles += 2;
466+
});
467+
// Note: check that we do not emit too many shuffles here to prevent code
468+
// size explosion.
469+
// TODO: investigate, if it can be improved by extra analysis of the masks
470+
// to check if the code is more profitable.
471+
if ((NumOfDestRegs > 2 && NumShuffles <= static_cast<int>(NumOfDestRegs)) ||
472+
(NumOfDestRegs <= 2 && NumShuffles < 4))
473+
return Cost;
474+
return InstructionCost::getInvalid();
475+
}
476+
379477
InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
380478
VectorType *Tp, ArrayRef<int> Mask,
381479
TTI::TargetCostKind CostKind,
@@ -389,7 +487,11 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
389487
// First, handle cases where having a fixed length vector enables us to
390488
// give a more accurate cost than falling back to generic scalable codegen.
391489
// TODO: Each of these cases hints at a modeling gap around scalable vectors.
392-
if (isa<FixedVectorType>(Tp)) {
490+
if (ST->hasVInstructions() && isa<FixedVectorType>(Tp)) {
491+
InstructionCost VRegSplittingCost = costShuffleViaVRegSplitting(
492+
*this, LT.second, ST->getRealVLen(), Tp, Mask, CostKind);
493+
if (VRegSplittingCost.isValid())
494+
return VRegSplittingCost;
393495
switch (Kind) {
394496
default:
395497
break;

0 commit comments

Comments
 (0)