Skip to content

Commit f5f55ad

Browse files
committed
[RISCV][TTI] Common a check in getShufleCost [nfc]
None of the vector costings apply if we're scalarizing. Pull that check into an early guard instead.
1 parent 9b7bf1f commit f5f55ad

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,9 @@ static InstructionCost
393393
costShuffleViaVRegSplitting(RISCVTTIImpl &TTI, MVT LegalVT,
394394
std::optional<unsigned> VLen, VectorType *Tp,
395395
ArrayRef<int> Mask, TTI::TargetCostKind CostKind) {
396+
assert(LegalVT.isFixedLengthVector());
396397
InstructionCost NumOfDests = InstructionCost::getInvalid();
397-
if (VLen && LegalVT.isFixedLengthVector() && !Mask.empty()) {
398+
if (VLen && !Mask.empty()) {
398399
MVT ElemVT = LegalVT.getVectorElementType();
399400
unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
400401
LegalVT = TTI.getTypeLegalizationCost(
@@ -404,7 +405,6 @@ costShuffleViaVRegSplitting(RISCVTTIImpl &TTI, MVT LegalVT,
404405
NumOfDests = divideCeil(Mask.size(), LegalVT.getVectorNumElements());
405406
}
406407
if (!NumOfDests.isValid() || NumOfDests <= 1 ||
407-
!LegalVT.isFixedLengthVector() ||
408408
LegalVT.getVectorElementType().getSizeInBits() !=
409409
Tp->getElementType()->getPrimitiveSizeInBits() ||
410410
LegalVT.getVectorNumElements() >= Tp->getElementCount().getFixedValue())
@@ -487,7 +487,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
487487
// First, handle cases where having a fixed length vector enables us to
488488
// give a more accurate cost than falling back to generic scalable codegen.
489489
// TODO: Each of these cases hints at a modeling gap around scalable vectors.
490-
if (ST->hasVInstructions() && isa<FixedVectorType>(Tp)) {
490+
if (ST->hasVInstructions() && isa<FixedVectorType>(Tp) &&
491+
LT.second.isFixedLengthVector()) {
491492
InstructionCost VRegSplittingCost = costShuffleViaVRegSplitting(
492493
*this, LT.second, ST->getRealVLen(), Tp, Mask, CostKind);
493494
if (VRegSplittingCost.isValid())
@@ -496,7 +497,7 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
496497
default:
497498
break;
498499
case TTI::SK_PermuteSingleSrc: {
499-
if (Mask.size() >= 2 && LT.second.isFixedLengthVector()) {
500+
if (Mask.size() >= 2) {
500501
MVT EltTp = LT.second.getVectorElementType();
501502
// If the size of the element is < ELEN then shuffles of interleaves and
502503
// deinterleaves of 2 vectors can be lowered into the following
@@ -545,10 +546,10 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
545546
}
546547
// vrgather + cost of generating the mask constant.
547548
// We model this for an unknown mask with a single vrgather.
548-
if (LT.second.isFixedLengthVector() && LT.first == 1 &&
549-
(LT.second.getScalarSizeInBits() != 8 ||
550-
LT.second.getVectorNumElements() <= 256)) {
551-
VectorType *IdxTy = getVRGatherIndexType(LT.second, *ST, Tp->getContext());
549+
if (LT.first == 1 && (LT.second.getScalarSizeInBits() != 8 ||
550+
LT.second.getVectorNumElements() <= 256)) {
551+
VectorType *IdxTy =
552+
getVRGatherIndexType(LT.second, *ST, Tp->getContext());
552553
InstructionCost IndexCost = getConstantPoolLoadCost(IdxTy, CostKind);
553554
return IndexCost +
554555
getRISCVInstructionCost(RISCV::VRGATHER_VV, LT.second, CostKind);
@@ -560,9 +561,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
560561
// 2 x (vrgather + cost of generating the mask constant) + cost of mask
561562
// register for the second vrgather. We model this for an unknown
562563
// (shuffle) mask.
563-
if (LT.second.isFixedLengthVector() && LT.first == 1 &&
564-
(LT.second.getScalarSizeInBits() != 8 ||
565-
LT.second.getVectorNumElements() <= 256)) {
564+
if (LT.first == 1 && (LT.second.getScalarSizeInBits() != 8 ||
565+
LT.second.getVectorNumElements() <= 256)) {
566566
auto &C = Tp->getContext();
567567
auto EC = Tp->getElementCount();
568568
VectorType *IdxTy = getVRGatherIndexType(LT.second, *ST, C);
@@ -581,7 +581,6 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
581581
// multiple destinations. Providing an accurate cost only for splits where
582582
// the element type remains the same.
583583
if (!Mask.empty() && LT.first.isValid() && LT.first != 1 &&
584-
LT.second.isFixedLengthVector() &&
585584
LT.second.getVectorElementType().getSizeInBits() ==
586585
Tp->getElementType()->getPrimitiveSizeInBits() &&
587586
LT.second.getVectorNumElements() <

0 commit comments

Comments
 (0)