Skip to content

Commit b1ef919

Browse files
committed
[ARM] Add CostKind to getMVEVectorCostFactor.
This adds a CostKind parameter to getMVEVectorCostFactor so that it can automatically account for code-size costs: for TCK_CodeSize it returns a cost of 1 rather than the MVE factor used for throughput/latency. This helps simplify the caller code and makes the code-size cost more accurate in more cases.
1 parent 059a335 commit b1ef919

File tree

6 files changed

+158
-150
lines changed

6 files changed

+158
-150
lines changed

llvm/lib/Target/ARM/ARMSubtarget.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "ARMISelLowering.h"
2121
#include "ARMSelectionDAGInfo.h"
2222
#include "llvm/ADT/Triple.h"
23+
#include "llvm/Analysis/TargetTransformInfo.h"
2324
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
2425
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
2526
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
@@ -911,7 +912,12 @@ class ARMSubtarget : public ARMGenSubtargetInfo {
911912

912913
unsigned getPrefLoopLogAlignment() const { return PrefLoopLogAlignment; }
913914

914-
unsigned getMVEVectorCostFactor() const { return MVEVectorCostFactor; }
915+
unsigned
916+
getMVEVectorCostFactor(TargetTransformInfo::TargetCostKind CostKind) const {
917+
if (CostKind == TargetTransformInfo::TCK_CodeSize)
918+
return 1;
919+
return MVEVectorCostFactor;
920+
}
915921

916922
bool ignoreCSRForAllocationOrder(const MachineFunction &MF,
917923
unsigned PhysReg) const override;

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
426426
(Opcode == Instruction::FPExt || Opcode == Instruction::FPTrunc) &&
427427
IsLegalFPType(SrcTy) && IsLegalFPType(DstTy)))
428428
if (CCH == TTI::CastContextHint::Masked && DstTy.getSizeInBits() > 128)
429-
return 2 * DstTy.getVectorNumElements() * ST->getMVEVectorCostFactor();
429+
return 2 * DstTy.getVectorNumElements() *
430+
ST->getMVEVectorCostFactor(CostKind);
430431

431432
// The extend of other kinds of load is free
432433
if (CCH == TTI::CastContextHint::Normal ||
@@ -470,7 +471,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
470471
if (const auto *Entry =
471472
ConvertCostTableLookup(MVELoadConversionTbl, ISD,
472473
DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
473-
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
474+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind);
474475
}
475476

476477
static const TypeConversionCostTblEntry MVEFLoadConversionTbl[] = {
@@ -482,7 +483,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
482483
if (const auto *Entry =
483484
ConvertCostTableLookup(MVEFLoadConversionTbl, ISD,
484485
DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
485-
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
486+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind);
486487
}
487488

488489
// The truncate of a store is free. This is the mirror of extends above.
@@ -499,7 +500,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
499500
if (const auto *Entry =
500501
ConvertCostTableLookup(MVEStoreConversionTbl, ISD,
501502
SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
502-
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
503+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind);
503504
}
504505

505506
static const TypeConversionCostTblEntry MVEFStoreConversionTbl[] = {
@@ -510,7 +511,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
510511
if (const auto *Entry =
511512
ConvertCostTableLookup(MVEFStoreConversionTbl, ISD,
512513
SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
513-
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
514+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind);
514515
}
515516
}
516517

@@ -734,7 +735,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
734735
if (const auto *Entry = ConvertCostTableLookup(MVEVectorConversionTbl,
735736
ISD, DstTy.getSimpleVT(),
736737
SrcTy.getSimpleVT()))
737-
return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
738+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind);
738739
}
739740

740741
if (ISD == ISD::FP_ROUND || ISD == ISD::FP_EXTEND) {
@@ -784,7 +785,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
784785
}
785786

786787
int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy()
787-
? ST->getMVEVectorCostFactor()
788+
? ST->getMVEVectorCostFactor(CostKind)
788789
: 1;
789790
return AdjustCost(
790791
BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
@@ -819,7 +820,7 @@ int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
819820
// vector, to prevent vectorising where we end up just scalarising the
820821
// result anyway.
821822
return std::max(BaseT::getVectorInstrCost(Opcode, ValTy, Index),
822-
ST->getMVEVectorCostFactor()) *
823+
ST->getMVEVectorCostFactor(TTI::TCK_RecipThroughput)) *
823824
cast<FixedVectorType>(ValTy)->getNumElements() / 2;
824825
}
825826

@@ -881,9 +882,8 @@ int ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
881882
// Default to cheap (throughput/size of 1 instruction) but adjust throughput
882883
// for "multiple beats" potentially needed by MVE instructions.
883884
int BaseCost = 1;
884-
if (CostKind != TTI::TCK_CodeSize && ST->hasMVEIntegerOps() &&
885-
ValTy->isVectorTy())
886-
BaseCost = ST->getMVEVectorCostFactor();
885+
if (ST->hasMVEIntegerOps() && ValTy->isVectorTy())
886+
BaseCost = ST->getMVEVectorCostFactor(CostKind);
887887

888888
return BaseCost *
889889
BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
@@ -1132,11 +1132,12 @@ int ARMTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
11321132

11331133
if (const auto *Entry = CostTableLookup(MVEDupTbl, ISD::VECTOR_SHUFFLE,
11341134
LT.second))
1135-
return LT.first * Entry->Cost * ST->getMVEVectorCostFactor();
1135+
return LT.first * Entry->Cost *
1136+
ST->getMVEVectorCostFactor(TTI::TCK_RecipThroughput);
11361137
}
11371138
}
11381139
int BaseCost = ST->hasMVEIntegerOps() && Tp->isVectorTy()
1139-
? ST->getMVEVectorCostFactor()
1140+
? ST->getMVEVectorCostFactor(TTI::TCK_RecipThroughput)
11401141
: 1;
11411142
return BaseCost * BaseT::getShuffleCost(Kind, Tp, Index, SubTp);
11421143
}
@@ -1262,9 +1263,8 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
12621263
// Default to cheap (throughput/size of 1 instruction) but adjust throughput
12631264
// for "multiple beats" potentially needed by MVE instructions.
12641265
int BaseCost = 1;
1265-
if (CostKind != TTI::TCK_CodeSize && ST->hasMVEIntegerOps() &&
1266-
Ty->isVectorTy())
1267-
BaseCost = ST->getMVEVectorCostFactor();
1266+
if (ST->hasMVEIntegerOps() && Ty->isVectorTy())
1267+
BaseCost = ST->getMVEVectorCostFactor(CostKind);
12681268

12691269
// The rest of this mostly follows what is done in BaseT::getArithmeticInstrCost,
12701270
// without treating floats as more expensive that scalars or increasing the
@@ -1321,11 +1321,11 @@ int ARMTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
13211321
: cast<Instruction>(I->getOperand(0))->getOperand(0)->getType();
13221322
if (SrcVTy->getNumElements() == 4 && SrcVTy->getScalarType()->isHalfTy() &&
13231323
DstTy->getScalarType()->isFloatTy())
1324-
return ST->getMVEVectorCostFactor();
1324+
return ST->getMVEVectorCostFactor(CostKind);
13251325
}
13261326

13271327
int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy()
1328-
? ST->getMVEVectorCostFactor()
1328+
? ST->getMVEVectorCostFactor(CostKind)
13291329
: 1;
13301330
return BaseCost * BaseT::getMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
13311331
CostKind, I);
@@ -1337,9 +1337,9 @@ unsigned ARMTTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
13371337
TTI::TargetCostKind CostKind) {
13381338
if (ST->hasMVEIntegerOps()) {
13391339
if (Opcode == Instruction::Load && isLegalMaskedLoad(Src, Alignment))
1340-
return ST->getMVEVectorCostFactor();
1340+
return ST->getMVEVectorCostFactor(CostKind);
13411341
if (Opcode == Instruction::Store && isLegalMaskedStore(Src, Alignment))
1342-
return ST->getMVEVectorCostFactor();
1342+
return ST->getMVEVectorCostFactor(CostKind);
13431343
}
13441344
if (!isa<FixedVectorType>(Src))
13451345
return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
@@ -1368,7 +1368,8 @@ int ARMTTIImpl::getInterleavedMemoryOpCost(
13681368
// vldN/vstN only support legal vector types of size 64 or 128 in bits.
13691369
// Accesses having vector types that are a multiple of 128 bits can be
13701370
// matched to more than one vldN/vstN instruction.
1371-
int BaseCost = ST->hasMVEIntegerOps() ? ST->getMVEVectorCostFactor() : 1;
1371+
int BaseCost =
1372+
ST->hasMVEIntegerOps() ? ST->getMVEVectorCostFactor(CostKind) : 1;
13721373
if (NumElts % Factor == 0 &&
13731374
TLI->isLegalInterleavedAccessType(Factor, SubVecTy, Alignment, DL))
13741375
return Factor * BaseCost * TLI->getNumInterleavedAccesses(SubVecTy, DL);
@@ -1413,7 +1414,8 @@ unsigned ARMTTIImpl::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
14131414
// multiplied by the number of elements being loaded. This is possibly very
14141415
// conservative, but even so we still end up vectorising loops because the
14151416
// cost per iteration for many loops is lower than for scalar loops.
1416-
unsigned VectorCost = NumElems * LT.first * ST->getMVEVectorCostFactor();
1417+
unsigned VectorCost =
1418+
NumElems * LT.first * ST->getMVEVectorCostFactor(CostKind);
14171419
// The scalarization cost should be a lot higher. We use the number of vector
14181420
// elements plus the scalarization overhead.
14191421
unsigned ScalarCost = NumElems * LT.first +
@@ -1506,7 +1508,7 @@ int ARMTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
15061508
{ISD::ADD, MVT::v4i32, 1},
15071509
};
15081510
if (const auto *Entry = CostTableLookup(CostTblAdd, ISD, LT.second))
1509-
return Entry->Cost * ST->getMVEVectorCostFactor() * LT.first;
1511+
return Entry->Cost * ST->getMVEVectorCostFactor(CostKind) * LT.first;
15101512

15111513
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
15121514
CostKind);
@@ -1524,7 +1526,7 @@ ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
15241526
(LT.second == MVT::v8i16 &&
15251527
ResVT.getSizeInBits() <= (IsMLA ? 64 : 32)) ||
15261528
(LT.second == MVT::v4i32 && ResVT.getSizeInBits() <= 64))
1527-
return ST->getMVEVectorCostFactor() * LT.first;
1529+
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
15281530
}
15291531

15301532
return BaseT::getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, ValTy,
@@ -1566,7 +1568,7 @@ int ARMTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
15661568
ICA.getReturnType()->getScalarSizeInBits()
15671569
? 1
15681570
: 4;
1569-
return LT.first * ST->getMVEVectorCostFactor() * Instrs;
1571+
return LT.first * ST->getMVEVectorCostFactor(CostKind) * Instrs;
15701572
}
15711573
break;
15721574
}

0 commit comments

Comments
 (0)