Skip to content

Commit bbaac6f

Browse files
[VPlan] Add cost model for CSA
1 parent 1441aea commit bbaac6f

File tree

4 files changed

+237
-171
lines changed

4 files changed

+237
-171
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7462,9 +7462,17 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
74627462
/// not have corresponding recipes in \p Plan and are not marked to be ignored
74637463
/// in \p CostCtx. This means the VPlan contains simplification that the legacy
74647464
/// cost-model did not account for.
7465-
static bool planContainsAdditionalSimplifications(VPlan &Plan,
7466-
VPCostContext &CostCtx,
7467-
Loop *TheLoop) {
7465+
static bool
7466+
planContainsAdditionalSimplifications(VPlan &Plan, VPCostContext &CostCtx,
7467+
Loop *TheLoop,
7468+
LoopVectorizationLegality &Legal) {
7469+
// CSA cost is more complicated since there is significant overhead in the
7470+
// preheader and middle block. It also contains recipes that are not backed by
7471+
// underlying instructions in the original loop. This makes it difficult to
7472+
// model in the legacy cost model.
7473+
if (!Legal.getCSAs().empty())
7474+
return true;
7475+
74687476
// First collect all instructions for the recipes in Plan.
74697477
auto GetInstructionForCost = [](const VPRecipeBase *R) -> Instruction * {
74707478
if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
@@ -7571,9 +7579,9 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
75717579
precomputeCosts(BestPlan, BestFactor.Width, CostCtx);
75727580
assert((BestFactor.Width == LegacyVF.Width ||
75737581
planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),
7574-
CostCtx, OrigLoop) ||
7582+
CostCtx, OrigLoop, *Legal) ||
75757583
planContainsAdditionalSimplifications(getPlanFor(LegacyVF.Width),
7576-
CostCtx, OrigLoop)) &&
7584+
CostCtx, OrigLoop, *Legal)) &&
75777585
" VPlan cost model and legacy cost model disagreed");
75787586
assert((BestFactor.Width.isScalar() || BestFactor.ScalarCost > 0) &&
75797587
"when vectorizing, the scalar cost must be computed.");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2897,6 +2897,9 @@ class VPCSAHeaderPHIRecipe final : public VPHeaderPHIRecipe {
28972897

28982898
void execute(VPTransformState &State) override;
28992899

2900+
InstructionCost computeCost(ElementCount VF,
2901+
VPCostContext &Ctx) const override;
2902+
29002903
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
29012904
/// Print the recipe.
29022905
void print(raw_ostream &O, const Twine &Indent,
@@ -2928,6 +2931,9 @@ class VPCSADataUpdateRecipe final : public VPSingleDefRecipe {
29282931

29292932
void execute(VPTransformState &State) override;
29302933

2934+
InstructionCost computeCost(ElementCount VF,
2935+
VPCostContext &Ctx) const override;
2936+
29312937
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
29322938
/// Print the recipe.
29332939
void print(raw_ostream &O, const Twine &Indent,
@@ -2974,6 +2980,9 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29742980

29752981
void execute(VPTransformState &State) override;
29762982

2983+
InstructionCost computeCost(ElementCount VF,
2984+
VPCostContext &Ctx) const override;
2985+
29772986
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
29782987
/// Print the recipe.
29792988
void print(raw_ostream &O, const Twine &Indent,
@@ -2984,7 +2993,7 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29842993
VPValue *getVPMaskSel() const { return getOperand(1); }
29852994
VPValue *getVPDataSel() const { return getOperand(2); }
29862995
VPValue *getVPCSAVLSel() const { return getOperand(3); }
2987-
bool usesEVL() { return getNumOperands() == 4; }
2996+
bool usesEVL() const { return getNumOperands() == 4; }
29882997
};
29892998

29902999
/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,24 @@ void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
24962496
State.set(this, DataPhi, Part);
24972497
}
24982498

2499+
/// Estimate the vector cost attributed to a CSA header phi for \p VF.
///
/// The phi itself contributes nothing here; the returned value is the cost of
/// the two splats that seed it (the initial mask and the initial data
/// vector).
///
/// \param VF  Vectorization factor being costed.
/// \param Ctx Cost context; only its TTI is consulted.
/// \returns 0 for a scalar VF, otherwise the sum of the two broadcast costs.
InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
                                                  VPCostContext &Ctx) const {
  // Nothing is widened for a scalar VF, so there is nothing to charge.
  if (VF.isScalar())
    return 0;

  Type *ScalarTy = getUnderlyingValue()->getType();
  auto *VTy = VectorType::get(ScalarTy, VF);
  // The mask is modelled as <VF x i1>, matching how the other CSA recipes'
  // cost functions build their mask type.
  // NOTE(review): assumes CSAInitMask splats an i1 value - confirm against
  // the code that materializes the initial mask.
  auto *MaskTy =
      VectorType::get(IntegerType::getInt1Ty(ScalarTy->getContext()), VF);
  const TargetTransformInfo &TTI = Ctx.TTI;

  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since there is no VPInstruction::computeCost support.
  InstructionCost C = 0;
  // CSAInitMask
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, MaskTy);
  // CSAInitData
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
  return C;
}
2516+
24992517
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25002518
void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
25012519
VPSlotTracker &SlotTracker) const {
@@ -2524,6 +2542,34 @@ void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
25242542
}
25252543
}
25262544

2545+
/// Estimate the per-iteration vector cost of a CSA data update for \p VF.
///
/// Besides the data select itself, this also charges for the helper
/// operations that accompany the update (any-active reduction, VL select and
/// mask update), because those have no cost hook of their own yet.
///
/// Selects are priced through getCmpSelInstrCost rather than
/// getArithmeticInstrCost: the latter is the hook for binary math/logic
/// operations and does not consult the target's select cost tables.
///
/// \param VF  Vectorization factor being costed.
/// \param Ctx Cost context; only its TTI is consulted.
/// \returns 0 for a scalar VF, otherwise the summed reciprocal-throughput
///          cost of the operations listed in the body.
InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
                                                   VPCostContext &Ctx) const {
  // Nothing is widened for a scalar VF, so there is nothing to charge.
  if (VF.isScalar())
    return 0;

  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
  auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Cost of a select producing ValTy. The condition is modelled with the i1
  // type of matching shape; no predicate is known, hence BAD_ICMP_PREDICATE.
  auto SelectCost = [&](Type *ValTy) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy,
                                  CmpInst::makeCmpResultType(ValTy),
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };

  InstructionCost C = 0;

  // Data Update
  C += SelectCost(VTy);

  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since they are related to updating the data and there is
  // no VPInstruction::computeCost support at the moment.

  // NOTE(review): the original comment labels this term "CSAInitMask
  // AnyActive"; which instruction it stands for is not clear from this
  // function - confirm. It is kept as a data-typed select.
  C += SelectCost(VTy);
  // vp.reduce.or
  // An or-reduction is only defined for integer vectors and the data type
  // need not be one, so the reduction is costed on the mask type.
  // NOTE(review): assumes the value being reduced is the <VF x i1> mask -
  // confirm against the any-active codegen.
  C += TTI.getArithmeticReductionCost(Instruction::Or, MaskTy, std::nullopt,
                                      CostKind);
  // VPVLSel
  // NOTE(review): costed on the data vector type as before; if the VL select
  // operates on a scalar length this overestimates - confirm.
  C += SelectCost(VTy);
  // MaskUpdate
  C += SelectCost(MaskTy);
  return C;
}
2572+
25272573
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25282574
void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
25292575
VPSlotTracker &SlotTracker) const {
@@ -2584,6 +2630,60 @@ void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
25842630
State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
25852631
}
25862632

2633+
/// Estimate the cost of extracting the final scalar of a CSA for \p VF.
///
/// The sequence being priced is: build a step vector and a -1 splat, derive
/// the index of the last active lane (a single smax reduction when EVL is
/// used, a longer mask-based sequence otherwise), extract that lane from the
/// data vector, and finally choose between the extracted value and the
/// initial scalar.
///
/// Compares and selects are priced through getCmpSelInstrCost and element
/// extracts through getVectorInstrCost; getArithmeticInstrCost is reserved
/// for the one genuine binary operator (the And).
///
/// Operand types follow the dataflow visible in this function: lane indices
/// are i32 because they come out of an smax reduction over an i32 vector.
/// NOTE(review): the types were derived from the step names and that
/// dataflow, not from the codegen - confirm against execute().
///
/// \param VF  Vectorization factor being costed.
/// \param Ctx Cost context; only its TTI is consulted.
/// \returns 0 for a scalar VF, otherwise the summed reciprocal-throughput
///          cost of the sequence above.
InstructionCost
VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
                                      VPCostContext &Ctx) const {
  // Nothing is widened for a scalar VF, so there is nothing to charge.
  if (VF.isScalar())
    return 0;

  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
  Type *Int32Ty = IntegerType::getInt32Ty(VTy->getContext());
  Type *Int1Ty = IntegerType::getInt1Ty(VTy->getContext());
  auto *Int32VTy = VectorType::get(Int32Ty, VF);
  auto *MaskTy = VectorType::get(Int1Ty, VF);
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Cost of a select producing ValTy; the condition is the i1 type of
  // matching shape and no predicate is known.
  auto SelectCost = [&](Type *ValTy) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy,
                                  CmpInst::makeCmpResultType(ValTy),
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };
  // Cost of an integer compare whose operands have type OpTy.
  auto ICmpCost = [&](Type *OpTy, CmpInst::Predicate Pred) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::ICmp, OpTy,
                                  CmpInst::makeCmpResultType(OpTy), Pred,
                                  CostKind);
  };
  // Cost of the smax reduction that collapses the lane-index vector.
  auto LastIdxReductionCost = [&]() -> InstructionCost {
    return TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy,
                                      FastMathFlags(), CostKind);
  };

  InstructionCost C = 0;

  // StepVector
  ArrayRef<Value *> Args;
  IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
  C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
  // NegOneSplat
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);
  // LastIdx
  if (usesEVL()) {
    C += LastIdxReductionCost();
  } else {
    // ActiveLaneIdxs: a vector select whose result feeds the reduction below,
    // so it is costed on the i32 index vector.
    C += SelectCost(Int32VTy);
    // MaybeLastIdx
    C += LastIdxReductionCost();
    // IsLaneZeroActive: reads lane 0 of the mask.
    C += TTI.getVectorInstrCost(Instruction::ExtractElement, MaskTy, CostKind,
                                /*Index=*/0);
    // MaybeLastIdxEQZero: scalar compare of the reduced i32 index.
    C += ICmpCost(Int32Ty, CmpInst::ICMP_EQ);
    // And
    C += TTI.getArithmeticInstrCost(Instruction::And, Int1Ty, CostKind);
    // LastIdx: scalar select between i32 index values.
    C += SelectCost(Int32Ty);
  }
  // ExtractFromVec: the lane is only known at run time, so no index is given.
  C += TTI.getVectorInstrCost(Instruction::ExtractElement, VTy, CostKind);
  // LastIdxGeZero: scalar compare of the i32 index.
  C += ICmpCost(Int32Ty, CmpInst::ICMP_SGE);
  // ChooseFromVecOrInit
  C += SelectCost(VTy->getScalarType());
  return C;
}
2686+
25872687
void VPBranchOnMaskRecipe::execute(VPTransformState &State) {
25882688
assert(State.Lane && "Branch on Mask works only on single instance.");
25892689

0 commit comments

Comments
 (0)