Skip to content

Commit bca407a

Browse files
[VPlan] Add cost model for CSA
1 parent 047a3ff commit bca407a

File tree

4 files changed

+236
-170
lines changed

4 files changed

+236
-170
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7438,9 +7438,17 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
74387438
/// not have corresponding recipes in \p Plan and are not marked to be ignored
74397439
/// in \p CostCtx. This means the VPlan contains simplification that the legacy
74407440
/// cost-model did not account for.
7441-
static bool planContainsAdditionalSimplifications(VPlan &Plan,
7442-
VPCostContext &CostCtx,
7443-
Loop *TheLoop) {
7441+
static bool
7442+
planContainsAdditionalSimplifications(VPlan &Plan, VPCostContext &CostCtx,
7443+
Loop *TheLoop,
7444+
LoopVectorizationLegality &Legal) {
7445+
// CSA cost is more complicated since there is significant overhead in the
7446+
// preheader and middle block. It also contains recipes that are not backed by
7447+
// underlying instructions in the original loop. This makes it difficult to
7448+
// model in the legacy cost model.
7449+
if (!Legal.getCSAs().empty())
7450+
return true;
7451+
74447452
// First collect all instructions for the recipes in Plan.
74457453
auto GetInstructionForCost = [](const VPRecipeBase *R) -> Instruction * {
74467454
if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
@@ -7547,7 +7555,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
75477555
precomputeCosts(BestPlan, BestFactor.Width, CostCtx);
75487556
assert((BestFactor.Width == LegacyVF.Width ||
75497557
planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),
7550-
CostCtx, OrigLoop)) &&
7558+
CostCtx, OrigLoop, *Legal)) &&
75517559
" VPlan cost model and legacy cost model disagreed");
75527560
assert((BestFactor.Width.isScalar() || BestFactor.ScalarCost > 0) &&
75537561
"when vectorizing, the scalar cost must be computed.");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2850,6 +2850,9 @@ class VPCSAHeaderPHIRecipe final : public VPHeaderPHIRecipe {
28502850

28512851
void execute(VPTransformState &State) override;
28522852

2853+
InstructionCost computeCost(ElementCount VF,
2854+
VPCostContext &Ctx) const override;
2855+
28532856
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
28542857
/// Print the recipe.
28552858
void print(raw_ostream &O, const Twine &Indent,
@@ -2881,6 +2884,9 @@ class VPCSADataUpdateRecipe final : public VPSingleDefRecipe {
28812884

28822885
void execute(VPTransformState &State) override;
28832886

2887+
InstructionCost computeCost(ElementCount VF,
2888+
VPCostContext &Ctx) const override;
2889+
28842890
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
28852891
/// Print the recipe.
28862892
void print(raw_ostream &O, const Twine &Indent,
@@ -2927,6 +2933,9 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29272933

29282934
void execute(VPTransformState &State) override;
29292935

2936+
InstructionCost computeCost(ElementCount VF,
2937+
VPCostContext &Ctx) const override;
2938+
29302939
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
29312940
/// Print the recipe.
29322941
void print(raw_ostream &O, const Twine &Indent,
@@ -2937,7 +2946,7 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29372946
/// Returns operand 1 of this recipe (presumably the CSA mask select, going by
/// the accessor name -- confirm against the constructor's operand order).
VPValue *getVPMaskSel() const { return getOperand(1); }
29382947
/// Returns operand 2 of this recipe (presumably the CSA data select, going by
/// the accessor name -- confirm against the constructor's operand order).
VPValue *getVPDataSel() const { return getOperand(2); }
29392948
/// Returns operand 3 of this recipe. NOTE(review): usesEVL() treats a fourth
/// operand as optional, so this presumably must only be called when
/// usesEVL() is true -- confirm.
VPValue *getVPCSAVLSel() const { return getOperand(3); }
2940-
bool usesEVL() { return getNumOperands() == 4; }
2949+
/// Returns true when the recipe has exactly four operands. The fourth operand
/// is presumably the explicit-vector-length select (see getVPCSAVLSel) --
/// confirm against the constructor. Const so it can be queried from const
/// members such as computeCost().
bool usesEVL() const { return getNumOperands() == 4; }
29412950
};
29422951

29432952
/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2489,6 +2489,24 @@ void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
24892489
State.set(this, DataPhi, Part);
24902490
}
24912491

2492+
/// Estimate the cost this CSA header phi contributes at vectorization factor
/// \p VF, using the target hooks in \p Ctx.
///
/// The phi itself is not charged. The returned value is the sum of the two
/// splats that seed the CSA mask and CSA data vectors. A scalar \p VF costs
/// nothing.
InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
                                                  VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  Type *ScalarTy = getUnderlyingValue()->getType();
  auto *VTy = VectorType::get(ScalarTy, VF);
  // The mask is modelled as <VF x i1>, the same mask type the other CSA
  // recipes' cost functions use. Costing the mask splat with the data type
  // would over- or under-charge whenever the data element is not i1.
  auto *MaskTy =
      VectorType::get(IntegerType::getInt1Ty(ScalarTy->getContext()), VF);
  const TargetTransformInfo &TTI = Ctx.TTI;

  InstructionCost C = 0;
  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since there is no VPInstruction::computeCost support.
  // CSAInitMask: splat of the initial mask value.
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, MaskTy);
  // CSAInitData: splat of the initial data value.
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
  return C;
}
2509+
24922510
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24932511
void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
24942512
VPSlotTracker &SlotTracker) const {
@@ -2517,6 +2535,34 @@ void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
25172535
}
25182536
}
25192537

2538+
/// Estimate the per-iteration cost of updating the CSA data at vectorization
/// factor \p VF, using the target hooks in \p Ctx.
///
/// Besides the data select itself, this also charges the companion
/// operations (any-active reduction, VL select, mask update) that do not yet
/// have a cost of their own. A scalar \p VF costs nothing.
InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
                                                   VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  Type *ScalarTy = getUnderlyingValue()->getType();
  Type *I1Ty = IntegerType::getInt1Ty(ScalarTy->getContext());
  auto *VTy = VectorType::get(ScalarTy, VF);
  auto *MaskTy = VectorType::get(I1Ty, VF);
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Selects must be costed through getCmpSelInstrCost; getArithmeticInstrCost
  // is the hook for binary operators and does not model a select.
  // NOTE(review): the condition is modelled as the scalar i1 "any active" bit;
  // confirm against the IR emitted by execute().
  auto SelectCost = [&](Type *ValTy) {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy, I1Ty,
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };

  InstructionCost C = 0;
  // Data Update
  C += SelectCost(VTy);

  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since they are related to updating the data and there is
  // no VPInstruction::computeCost support at the moment.
  // AnyActive select.
  // NOTE(review): value type kept from the previous model; confirm which
  // operation this entry stands for.
  C += SelectCost(VTy);
  // vp.reduce.or: the reduction runs over the mask, so it is costed on the i1
  // vector. The data element type may be floating point, for which an integer
  // 'or' reduction is not meaningful.
  C += TTI.getArithmeticReductionCost(Instruction::Or, MaskTy, std::nullopt,
                                      CostKind);
  // VPVLSel
  // NOTE(review): value type kept from the previous model; a vector-length
  // select presumably operates on a scalar integer -- confirm.
  C += SelectCost(VTy);
  // MaskUpdate
  C += SelectCost(MaskTy);
  return C;
}
2565+
25202566
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25212567
void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
25222568
VPSlotTracker &SlotTracker) const {
@@ -2577,6 +2623,60 @@ void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
25772623
State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
25782624
}
25792625

2626+
/// Estimate the cost of extracting the final scalar CSA value at
/// vectorization factor \p VF, using the target hooks in \p Ctx.
///
/// The sequence modelled is: build a lane-index vector, reduce it to the index
/// of the last active lane (directly when EVL is used, otherwise through a
/// masked select plus a lane-zero fix-up), extract that lane from the data
/// vector, and finally choose between the extracted value and the initial
/// value. A scalar \p VF costs nothing.
InstructionCost
VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
                                      VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  Type *ScalarTy = getUnderlyingValue()->getType();
  LLVMContext &LLVMCtx = ScalarTy->getContext();
  Type *I1Ty = IntegerType::getInt1Ty(LLVMCtx);
  Type *I32Ty = IntegerType::getInt32Ty(LLVMCtx);
  auto *VTy = VectorType::get(ScalarTy, VF);
  auto *Int32VTy = VectorType::get(I32Ty, VF);
  auto *MaskTy = VectorType::get(I1Ty, VF);
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Selects and compares go through getCmpSelInstrCost and element extracts
  // through getVectorInstrCost; getArithmeticInstrCost only models binary
  // operators.
  auto SelectCost = [&](Type *ValTy, Type *CondTy) {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy, CondTy,
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };
  auto ScalarI32CmpCost = [&](CmpInst::Predicate Pred) {
    return TTI.getCmpSelInstrCost(Instruction::ICmp, I32Ty, I1Ty, Pred,
                                  CostKind);
  };

  InstructionCost C = 0;
  // StepVector
  ArrayRef<Value *> Args;
  IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
  C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
  // NegOneSplat
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);
  // LastIdx
  if (usesEVL()) {
    C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
                                    CostKind);
  } else {
    // ActiveLaneIdxs: a mask-controlled select whose result feeds the i32
    // vector reduction below, so it is a vector select over Int32VTy.
    C += SelectCost(Int32VTy, MaskTy);
    // MaybeLastIdx
    C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
                                    CostKind);
    // IsLaneZeroActive: extract of lane 0 of the mask.
    C += TTI.getVectorInstrCost(Instruction::ExtractElement, MaskTy, CostKind,
                                /*Index=*/0);
    // MaybeLastIdxEQZero: compares the scalar i32 reduction result with zero.
    C += ScalarI32CmpCost(CmpInst::ICMP_EQ);
    // And
    C += TTI.getArithmeticInstrCost(Instruction::And, I1Ty, CostKind);
    // LastIdx: selects a scalar lane index, so the value type is i32 and not
    // the data element type.
    C += SelectCost(I32Ty, I1Ty);
  }
  // ExtractFromVec: the lane index is only known at run time.
  C += TTI.getVectorInstrCost(Instruction::ExtractElement, VTy, CostKind);
  // LastIdxGeZero: LastIdx is a scalar, so this is a scalar i32 compare.
  C += ScalarI32CmpCost(CmpInst::ICMP_SGE);
  // ChooseFromVecOrInit: scalar select on the data element type.
  C += SelectCost(ScalarTy, I1Ty);
  return C;
}
2679+
25802680
void VPBranchOnMaskRecipe::execute(VPTransformState &State) {
25812681
assert(State.Lane && "Branch on Mask works only on single instance.");
25822682

0 commit comments

Comments
 (0)