Skip to content

Commit b873aba

Browse files
committed
[LoopVectorizer] NFCI: Calculate register usage based on TLI.getTypeLegalizationCost.
This is more accurate than dividing the bitwidth based on the element count by the maximum register size, as it can just reuse whatever has been calculated for legalization of these types. This change is also necessary when calculating register usage for scalable vectors, where the legalization of these types cannot be done based on the widest register size, because that does not take the 'vscale' component into account. Reviewed By: SjoerdMeijer Differential Revision: https://reviews.llvm.org/D91059
1 parent 91ce6fb commit b873aba

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,9 @@ class TargetTransformInfo {
708708
/// Return true if this type is legal.
709709
bool isTypeLegal(Type *Ty) const;
710710

711+
/// Returns the estimated number of registers required to represent \p Ty.
712+
unsigned getRegUsageForType(Type *Ty) const;
713+
711714
/// Return true if switches should be turned into lookup tables for the
712715
/// target.
713716
bool shouldBuildLookupTables() const;
@@ -1447,6 +1450,7 @@ class TargetTransformInfo::Concept {
14471450
virtual bool isProfitableToHoist(Instruction *I) = 0;
14481451
virtual bool useAA() = 0;
14491452
virtual bool isTypeLegal(Type *Ty) = 0;
1453+
virtual unsigned getRegUsageForType(Type *Ty) = 0;
14501454
virtual bool shouldBuildLookupTables() = 0;
14511455
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
14521456
virtual bool useColdCCForColdCall(Function &F) = 0;
@@ -1807,6 +1811,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
18071811
}
18081812
bool useAA() override { return Impl.useAA(); }
18091813
bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
1814+
unsigned getRegUsageForType(Type *Ty) override {
1815+
return Impl.getRegUsageForType(Ty);
1816+
}
18101817
bool shouldBuildLookupTables() override {
18111818
return Impl.shouldBuildLookupTables();
18121819
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ class TargetTransformInfoImplBase {
259259

260260
bool isTypeLegal(Type *Ty) { return false; }
261261

262+
unsigned getRegUsageForType(Type *Ty) { return 1; }
263+
262264
bool shouldBuildLookupTables() { return true; }
263265
bool shouldBuildLookupTablesForConstant(Constant *C) { return true; }
264266

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
297297
return getTLI()->isTypeLegal(VT);
298298
}
299299

300+
unsigned getRegUsageForType(Type *Ty) {
301+
return getTLI()->getTypeLegalizationCost(DL, Ty).first;
302+
}
303+
300304
int getGEPCost(Type *PointeeType, const Value *Ptr,
301305
ArrayRef<const Value *> Operands) {
302306
return BaseT::getGEPCost(PointeeType, Ptr, Operands);

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,10 @@ bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
482482
return TTIImpl->isTypeLegal(Ty);
483483
}
484484

485+
unsigned TargetTransformInfo::getRegUsageForType(Type *Ty) const {
486+
return TTIImpl->getRegUsageForType(Ty);
487+
}
488+
485489
bool TargetTransformInfo::shouldBuildLookupTables() const {
486490
return TTIImpl->shouldBuildLookupTables();
487491
}

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5793,8 +5793,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
57935793
unsigned MaxSafeDepDist = -1U;
57945794
if (Legal->getMaxSafeDepDistBytes() != -1U)
57955795
MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8;
5796-
unsigned WidestRegister =
5797-
std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist);
57985796
const DataLayout &DL = TheFunction->getParent()->getDataLayout();
57995797

58005798
SmallVector<RegisterUsage, 8> RUs(VFs.size());
@@ -5803,13 +5801,10 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
58035801
LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
58045802

58055803
// A lambda that gets the register usage for the given type and VF.
5806-
auto GetRegUsage = [&DL, WidestRegister](Type *Ty, ElementCount VF) {
5804+
auto GetRegUsage = [&DL, &TTI=TTI](Type *Ty, ElementCount VF) {
58075805
if (Ty->isTokenTy())
58085806
return 0U;
5809-
unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType());
5810-
assert(!VF.isScalable() && "scalable vectors not yet supported.");
5811-
return std::max<unsigned>(1, VF.getKnownMinValue() * TypeSize /
5812-
WidestRegister);
5807+
return TTI.getRegUsageForType(VectorType::get(Ty, VF));
58135808
};
58145809

58155810
for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) {

0 commit comments

Comments
 (0)