Skip to content

[SLP]: Introduce and use getDataFlowCost #112999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,14 @@ class TargetTransformInfo {
Function *F, Type *RetTy, ArrayRef<Type *> Tys,
TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency) const;

/// \returns The cost of propagating a value of type \p DataType across basic
/// block or function boundaries.
///
/// If \p IsCallingConv is true, then \p DataType is associated with either a
/// function argument or a return value. Otherwise, \p DataType is used in
/// either a GEP instruction, or spans across basic blocks (this is relevant
/// because the SelectionDAG builder may, for example, scalarize illegal
/// vectors across blocks, which introduces extract/insert code).
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) const;

/// \returns The number of pieces into which the provided type must be
/// split during legalization. Zero is returned when the answer is unknown.
unsigned getNumberOfParts(Type *Tp) const;
Expand Down Expand Up @@ -2096,6 +2104,8 @@ class TargetTransformInfo::Concept {
virtual InstructionCost getCallInstrCost(Function *F, Type *RetTy,
ArrayRef<Type *> Tys,
TTI::TargetCostKind CostKind) = 0;
/// Type-erased hook for the data-flow cost query; the Model override
/// forwards it to the wrapped implementation.
virtual InstructionCost getDataFlowCost(Type *DataType,
bool IsCallingConv) = 0;
virtual unsigned getNumberOfParts(Type *Tp) = 0;
virtual InstructionCost
getAddressComputationCost(Type *Ty, ScalarEvolution *SE, const SCEV *Ptr) = 0;
Expand Down Expand Up @@ -2781,6 +2791,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
TTI::TargetCostKind CostKind) override {
return Impl.getCallInstrCost(F, RetTy, Tys, CostKind);
}
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) override {
  // Pure forwarder: the wrapped implementation decides the actual cost.
  InstructionCost FlowCost = Impl.getDataFlowCost(DataType, IsCallingConv);
  return FlowCost;
}
unsigned getNumberOfParts(Type *Tp) override {
  // Pure forwarder to the wrapped implementation.
  const unsigned NumParts = Impl.getNumberOfParts(Tp);
  return NumParts;
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,10 @@ class TargetTransformInfoImplBase {
return 1;
}

/// Default data-flow cost: both arguments are ignored and the propagation is
/// reported as free.
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) const {
  const InstructionCost Free = 0;
  return Free;
}

// Assume that a register of the right size exists for the type, so the type
// is reported as a single part.
unsigned getNumberOfParts(Type *Tp) const {
  return 1;
}

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2410,6 +2410,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return 10;
}

/// Default data-flow cost: both arguments are ignored and the propagation is
/// reported as free.
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) {
  const InstructionCost Free = 0;
  return Free;
}

unsigned getNumberOfParts(Type *Tp) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
return LT.first.isValid() ? *LT.first.getValue() : 0;
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,13 @@ TargetTransformInfo::getCallInstrCost(Function *F, Type *RetTy,
return Cost;
}

/// Queries the underlying TTI implementation for the cost of propagating
/// \p DataType and checks (in asserts builds) that the result is not negative.
InstructionCost
TargetTransformInfo::getDataFlowCost(Type *DataType,
                                     bool IsCallingConv) const {
  InstructionCost FlowCost =
      TTIImpl->getDataFlowCost(DataType, IsCallingConv);
  // Implementations are expected to report a non-negative cost.
  assert(FlowCost >= 0 && "TTI should not produce negative costs!");
  return FlowCost;
}

/// Forwards the query to the underlying TTI implementation unchanged.
unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const {
  const unsigned NumParts = TTIImpl->getNumberOfParts(Tp);
  return NumParts;
}
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,14 @@ bool GCNTTIImpl::hasBranchDivergence(const Function *F) const {
return !F || !ST->isSingleLaneExecution(*F);
}

/// Data-flow cost for \p DataType.
///
/// Calling-convention uses and legal types defer to the base implementation.
/// Any other type is charged the value returned by getNumberOfParts().
InstructionCost GCNTTIImpl::getDataFlowCost(Type *DataType,
                                            bool IsCallingConv) {
  // The legality query runs only when this is not a calling-convention use.
  const bool ChargePerPart = !IsCallingConv && !isTypeLegal(DataType);
  if (ChargePerPart)
    return getNumberOfParts(DataType);

  return BaseT::getDataFlowCost(DataType, IsCallingConv);
}

unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const {
// NB: RCID is not an RCID. In fact it is 0 or 1 for scalar or vector
// registers. See getRegisterClassForType for the implementation.
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);

/// \returns the base-class data-flow cost when \p IsCallingConv is true or
/// \p DataType is legal; otherwise the number of parts of \p DataType.
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv);

bool isInlineAsmSourceOfDivergence(const CallInst *CI,
ArrayRef<unsigned> Indices = {}) const;

Expand Down
53 changes: 47 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9075,15 +9075,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
auto *FinalVecTy = FixedVectorType::get(ScalarTy, EntryVF);

bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
InstructionCost CommonCost = 0;
if (E->State == TreeEntry::NeedToGather) {
if (allConstant(VL))
return 0;
return CommonCost;
if (isa<InsertElementInst>(VL[0]))
return InstructionCost::getInvalid();
return processBuildVector<ShuffleCostEstimator, InstructionCost>(
E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
return CommonCost +
processBuildVector<ShuffleCostEstimator, InstructionCost>(
E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
}
InstructionCost CommonCost = 0;
SmallVector<int> Mask;
bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
if (!E->ReorderIndices.empty() &&
Expand Down Expand Up @@ -9222,6 +9223,18 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
OpTE->Scalars.size());
}

// Calculate the cost difference of propagating a vector vs series of
// scalars across blocks. This may be nonzero in the case of illegal
// vectors.
Comment on lines +9227 to +9228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment talks about illegal vector types but the code affects only legal vector types

Type *ScalarTy = VL0->getType()->getScalarType();
if (ScalarTy && isValidElementType(ScalarTy)) {
ScalarCost += TTI->getDataFlowCost(ScalarTy,
/*IsCallingConv=*/false) *
Comment on lines +9231 to +9232
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it account for register pressure or something else?

EntryVF;
CommonCost += TTI->getDataFlowCost(
FixedVectorType::get(ScalarTy, EntryVF), /*IsCallingConv=*/false);
}

return CommonCost - ScalarCost;
}
case Instruction::ExtractValue:
Expand Down Expand Up @@ -10241,6 +10254,27 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {

InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts);
Cost += C;

// Calculate the cost difference of propagating a vector vs series of scalars
// across blocks. This may be nonzero in the case of illegal vectors.
Instruction *VL0 = TE.getMainOp();
if (VL0 && ((I + 1) < VectorizableTree.size())) {
Instruction *VL1 = VectorizableTree[I + 1]->getMainOp();
if (VL1 && (VL0->getParent() != VL1->getParent())) {
Type *ScalarTy = VL0->getType()->getScalarType();
if (ScalarTy && isValidElementType(ScalarTy)) {
InstructionCost ScalarDFlow =
TTI->getDataFlowCost(ScalarTy,
/*IsCallingConv=*/false) *
TE.getVectorFactor();
InstructionCost VectorDFlow = TTI->getDataFlowCost(
FixedVectorType::get(ScalarTy, TE.getVectorFactor()),
/*IsCallingConv=*/false);
Cost += (VectorDFlow - ScalarDFlow);
}
}
}

Comment on lines +10257 to +10277
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It should not be here, it should be implemented in getEntryCost.
  2. VectorizableTree[I + 1] does not always point to the operand of the previous node.

LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle "
<< shortBundleName(TE.Scalars) << ".\n"
<< "SLP: Current total cost = " << Cost << "\n");
Expand All @@ -10257,8 +10291,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
for (ExternalUser &EU : ExternalUses) {
// We only add extract cost once for the same scalar.
if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
!ExtractCostCalculated.insert(EU.Scalar).second)
!ExtractCostCalculated.insert(EU.Scalar).second) {
continue;
}

// Uses by ephemeral values are free (because the ephemeral value will be
// removed prior to code generation, and so the extraction will be
Expand All @@ -10267,8 +10302,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
continue;

// No extract cost for vector "scalar"
if (isa<FixedVectorType>(EU.Scalar->getType()))
if (isa<FixedVectorType>(EU.Scalar->getType())) {
// Account for any additional costs required by CallingConvention for the
// type.
if (isa_and_nonnull<ReturnInst>(EU.User))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outgoing call arguments too. But this is just the type legality cost?

Cost +=
TTI->getDataFlowCost(EU.Scalar->getType(), /*IsCallingConv=*/true);
continue;
}

// If found user is an insertelement, do not calculate extract cost but try
// to detect it as a final shuffled/identity match.
Expand Down
Loading