@@ -393,8 +393,9 @@ static InstructionCost
393
393
costShuffleViaVRegSplitting (RISCVTTIImpl &TTI, MVT LegalVT,
394
394
std::optional<unsigned > VLen, VectorType *Tp,
395
395
ArrayRef<int > Mask, TTI::TargetCostKind CostKind) {
396
+ assert (LegalVT.isFixedLengthVector ());
396
397
InstructionCost NumOfDests = InstructionCost::getInvalid ();
397
- if (VLen && LegalVT. isFixedLengthVector () && !Mask.empty ()) {
398
+ if (VLen && !Mask.empty ()) {
398
399
MVT ElemVT = LegalVT.getVectorElementType ();
399
400
unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits ();
400
401
LegalVT = TTI.getTypeLegalizationCost (
@@ -404,7 +405,6 @@ costShuffleViaVRegSplitting(RISCVTTIImpl &TTI, MVT LegalVT,
404
405
NumOfDests = divideCeil (Mask.size (), LegalVT.getVectorNumElements ());
405
406
}
406
407
if (!NumOfDests.isValid () || NumOfDests <= 1 ||
407
- !LegalVT.isFixedLengthVector () ||
408
408
LegalVT.getVectorElementType ().getSizeInBits () !=
409
409
Tp->getElementType ()->getPrimitiveSizeInBits () ||
410
410
LegalVT.getVectorNumElements () >= Tp->getElementCount ().getFixedValue ())
@@ -487,7 +487,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
487
487
// First, handle cases where having a fixed length vector enables us to
488
488
// give a more accurate cost than falling back to generic scalable codegen.
489
489
// TODO: Each of these cases hints at a modeling gap around scalable vectors.
490
- if (ST->hasVInstructions () && isa<FixedVectorType>(Tp)) {
490
+ if (ST->hasVInstructions () && isa<FixedVectorType>(Tp) &&
491
+ LT.second .isFixedLengthVector ()) {
491
492
InstructionCost VRegSplittingCost = costShuffleViaVRegSplitting (
492
493
*this , LT.second , ST->getRealVLen (), Tp, Mask, CostKind);
493
494
if (VRegSplittingCost.isValid ())
@@ -496,7 +497,7 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
496
497
default :
497
498
break ;
498
499
case TTI::SK_PermuteSingleSrc: {
499
- if (Mask.size () >= 2 && LT. second . isFixedLengthVector () ) {
500
+ if (Mask.size () >= 2 ) {
500
501
MVT EltTp = LT.second .getVectorElementType ();
501
502
// If the size of the element is < ELEN then shuffles of interleaves and
502
503
// deinterleaves of 2 vectors can be lowered into the following
@@ -545,10 +546,10 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
545
546
}
546
547
// vrgather + cost of generating the mask constant.
547
548
// We model this for an unknown mask with a single vrgather.
548
- if (LT.second . isFixedLengthVector () && LT.first == 1 &&
549
- ( LT.second .getScalarSizeInBits () != 8 ||
550
- LT. second . getVectorNumElements () <= 256 )) {
551
- VectorType *IdxTy = getVRGatherIndexType (LT.second , *ST, Tp->getContext ());
549
+ if (LT.first == 1 && ( LT.second . getScalarSizeInBits () != 8 ||
550
+ LT.second .getVectorNumElements () <= 256 )) {
551
+ VectorType *IdxTy =
552
+ getVRGatherIndexType (LT.second , *ST, Tp->getContext ());
552
553
InstructionCost IndexCost = getConstantPoolLoadCost (IdxTy, CostKind);
553
554
return IndexCost +
554
555
getRISCVInstructionCost (RISCV::VRGATHER_VV, LT.second , CostKind);
@@ -560,9 +561,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
560
561
// 2 x (vrgather + cost of generating the mask constant) + cost of mask
561
562
// register for the second vrgather. We model this for an unknown
562
563
// (shuffle) mask.
563
- if (LT.second .isFixedLengthVector () && LT.first == 1 &&
564
- (LT.second .getScalarSizeInBits () != 8 ||
565
- LT.second .getVectorNumElements () <= 256 )) {
564
+ if (LT.first == 1 && (LT.second .getScalarSizeInBits () != 8 ||
565
+ LT.second .getVectorNumElements () <= 256 )) {
566
566
auto &C = Tp->getContext ();
567
567
auto EC = Tp->getElementCount ();
568
568
VectorType *IdxTy = getVRGatherIndexType (LT.second , *ST, C);
@@ -581,7 +581,6 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
581
581
// multiple destinations. Providing an accurate cost only for splits where
582
582
// the element type remains the same.
583
583
if (!Mask.empty () && LT.first .isValid () && LT.first != 1 &&
584
- LT.second .isFixedLengthVector () &&
585
584
LT.second .getVectorElementType ().getSizeInBits () ==
586
585
Tp->getElementType ()->getPrimitiveSizeInBits () &&
587
586
LT.second .getVectorNumElements () <
0 commit comments