Skip to content

Commit 4aba2d3

Browse files
committed
[VPlan] Implement VPReductionRecipe::computeCost(). NFC
Implementation of `computeCost()` function for `VPReductionRecipe`.
1 parent af47038 commit 4aba2d3

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
23932393
/// Generate the reduction in the loop
23942394
void execute(VPTransformState &State) override;
23952395

2396+
/// Return the cost of this VPReductionRecipe.
2397+
InstructionCost computeCost(ElementCount VF,
2398+
VPCostContext &Ctx) const override;
2399+
23962400
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23972401
/// Print the recipe.
23982402
void print(raw_ostream &O, const Twine &Indent,

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,6 +2022,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
20222022
State.set(this, NewRed, /*IsScalar*/ true);
20232023
}
20242024

2025+
/// Compute the cost of this reduction recipe for vectorization factor \p VF.
///
/// The result is the sum of two target-provided costs queried through
/// \p Ctx.TTI: the arithmetic instruction cost of the recurrence opcode on the
/// recurrence type, plus the cost of reducing a vector of that type (the
/// min/max reduction cost for min/max recurrence kinds, the arithmetic
/// reduction cost otherwise).
///
/// \returns an invalid cost when widening the recurrence type by \p VF does
/// not produce a VectorType.
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                               VPCostContext &Ctx) const {
  Type *RecurTy = RdxDesc.getRecurrenceType();
  auto *WideTy = dyn_cast<VectorType>(ToVectorTy(RecurTy, VF));
  // Without a vector type there is nothing to reduce, so no cost can be given.
  if (!WideTy)
    return InstructionCost::getInvalid();

  const TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  const RecurKind Kind = RdxDesc.getRecurrenceKind();
  const unsigned Opcode = RdxDesc.getOpcode();

  // Cost = Reduction cost + BinOp cost
  InstructionCost BinOpCost =
      Ctx.TTI.getArithmeticInstrCost(Opcode, RecurTy, CostKind);

  InstructionCost ReductionCost;
  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
    // Min/max recurrences are costed via their reduction intrinsic.
    Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(Kind);
    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
        MinMaxID, WideTy, RdxDesc.getFastMathFlags(), CostKind);
  } else {
    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
        Opcode, WideTy, RdxDesc.getFastMathFlags(), CostKind);
  }

  return BinOpCost + ReductionCost;
}
2048+
20252049
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
20262050
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
20272051
VPSlotTracker &SlotTracker) const {

0 commit comments

Comments
 (0)