Skip to content

Commit 08fac18

Browse files
[VPlan] Add cost model for CSA
1 parent 5149e22 commit 08fac18

File tree

4 files changed

+237
-171
lines changed

4 files changed

+237
-171
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7441,9 +7441,17 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
74417441
/// not have corresponding recipes in \p Plan and are not marked to be ignored
74427442
/// in \p CostCtx. This means the VPlan contains simplification that the legacy
74437443
/// cost-model did not account for.
7444-
static bool planContainsAdditionalSimplifications(VPlan &Plan,
7445-
VPCostContext &CostCtx,
7446-
Loop *TheLoop) {
7444+
static bool
7445+
planContainsAdditionalSimplifications(VPlan &Plan, VPCostContext &CostCtx,
7446+
Loop *TheLoop,
7447+
LoopVectorizationLegality &Legal) {
7448+
// CSA cost is more complicated since there is significant overhead in the
7449+
// preheader and middle block. It also contains recipes that are not backed by
7450+
// underlying instructions in the original loop. This makes it difficult to
7451+
// model in the legacy cost model.
7452+
if (!Legal.getCSAs().empty())
7453+
return true;
7454+
74477455
// First collect all instructions for the recipes in Plan.
74487456
auto GetInstructionForCost = [](const VPRecipeBase *R) -> Instruction * {
74497457
if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
@@ -7550,9 +7558,9 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
75507558
precomputeCosts(BestPlan, BestFactor.Width, CostCtx);
75517559
assert((BestFactor.Width == LegacyVF.Width ||
75527560
planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),
7553-
CostCtx, OrigLoop) ||
7561+
CostCtx, OrigLoop, *Legal) ||
75547562
planContainsAdditionalSimplifications(getPlanFor(LegacyVF.Width),
7555-
CostCtx, OrigLoop)) &&
7563+
CostCtx, OrigLoop, *Legal)) &&
75567564
" VPlan cost model and legacy cost model disagreed");
75577565
assert((BestFactor.Width.isScalar() || BestFactor.ScalarCost > 0) &&
75587566
"when vectorizing, the scalar cost must be computed.");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2856,6 +2856,9 @@ class VPCSAHeaderPHIRecipe final : public VPHeaderPHIRecipe {
28562856

28572857
void execute(VPTransformState &State) override;
28582858

2859+
InstructionCost computeCost(ElementCount VF,
2860+
VPCostContext &Ctx) const override;
2861+
28592862
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
28602863
/// Print the recipe.
28612864
void print(raw_ostream &O, const Twine &Indent,
@@ -2887,6 +2890,9 @@ class VPCSADataUpdateRecipe final : public VPSingleDefRecipe {
28872890

28882891
void execute(VPTransformState &State) override;
28892892

2893+
InstructionCost computeCost(ElementCount VF,
2894+
VPCostContext &Ctx) const override;
2895+
28902896
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
28912897
/// Print the recipe.
28922898
void print(raw_ostream &O, const Twine &Indent,
@@ -2933,6 +2939,9 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29332939

29342940
void execute(VPTransformState &State) override;
29352941

2942+
InstructionCost computeCost(ElementCount VF,
2943+
VPCostContext &Ctx) const override;
2944+
29362945
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
29372946
/// Print the recipe.
29382947
void print(raw_ostream &O, const Twine &Indent,
@@ -2943,7 +2952,7 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
29432952
VPValue *getVPMaskSel() const { return getOperand(1); }
29442953
VPValue *getVPDataSel() const { return getOperand(2); }
29452954
VPValue *getVPCSAVLSel() const { return getOperand(3); }
2946-
bool usesEVL() { return getNumOperands() == 4; }
2955+
bool usesEVL() const { return getNumOperands() == 4; }
29472956
};
29482957

29492958
/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,6 +2491,24 @@ void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
24912491
State.set(this, DataPhi, Part);
24922492
}
24932493

2494+
/// Return the cost attributed to the CSA header phi for vectorization factor
/// \p VF. A scalar \p VF costs nothing; otherwise the result is the sum of two
/// broadcast-shuffle costs queried from the TTI in \p Ctx.
InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
                                                  VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  // Vector type built from the underlying value's type at the requested VF.
  VectorType *DataVecTy = VectorType::get(getUnderlyingValue()->getType(), VF);

  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since there is no VPInstruction::computeCost support.
  //
  // NOTE(review): both broadcasts are costed on the data vector type; confirm
  // whether the CSAInitMask splat should be costed on an i1 vector instead.
  const InstructionCost InitMaskCost =
      Ctx.TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, DataVecTy);
  const InstructionCost InitDataCost =
      Ctx.TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, DataVecTy);
  return InitMaskCost + InitDataCost;
}
2511+
24942512
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24952513
void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
24962514
VPSlotTracker &SlotTracker) const {
@@ -2519,6 +2537,34 @@ void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
25192537
}
25202538
}
25212539

2540+
/// Return the cost attributed to the CSA data update for vectorization factor
/// \p VF. A scalar \p VF costs nothing. Otherwise the result sums the data
/// select, the AnyActive select, an or-reduction, the VL select and the mask
/// update select, all queried from the TTI in \p Ctx with
/// TCK_RecipThroughput.
InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
                                                   VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
  auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
  // Spell the type out in full: the local reference below is also named TTI
  // and would otherwise shadow the `TTI` type alias.
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Selects are costed through the compare/select hook rather than the
  // arithmetic hook. The condition type is derived from the value type, and
  // no predicate is supplied.
  // NOTE(review): if some of these selects take a scalar i1 condition, pass
  // that as CondTy instead -- confirm against execute().
  auto SelectCost = [&](Type *ValTy) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy,
                                  CmpInst::makeCmpResultType(ValTy),
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };

  InstructionCost C = 0;

  // Data Update
  C += SelectCost(VTy);

  // FIXME: These costs should be moved into VPInstruction::computeCost. We put
  // them here for now since they are related to updating the data and there is
  // no VPInstruction::computeCost support at the moment.
  //
  // NOTE(review): AnyActive and the or-reduction are costed on the data vector
  // type, as before; confirm whether they should be costed on MaskTy.
  // AnyActive
  C += SelectCost(VTy);
  // vp.reduce.or
  C += TTI.getArithmeticReductionCost(Instruction::Or, VTy, std::nullopt,
                                      CostKind);
  // VPVLSel
  C += SelectCost(VTy);
  // MaskUpdate
  C += SelectCost(MaskTy);
  return C;
}
2567+
25222568
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25232569
void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
25242570
VPSlotTracker &SlotTracker) const {
@@ -2579,6 +2625,60 @@ void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
25792625
State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
25802626
}
25812627

2628+
/// Return the cost attributed to extracting the CSA scalar result for
/// vectorization factor \p VF. A scalar \p VF costs nothing. Otherwise the
/// result sums the step vector, the -1 splat, the last-index computation
/// (which has extra steps when usesEVL() is false), the element extract, and
/// the final compare and select. All costs are queried from the TTI in \p Ctx
/// with TCK_RecipThroughput.
InstructionCost
VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
                                      VPCostContext &Ctx) const {
  if (VF.isScalar())
    return 0;

  auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
  auto *Int32VTy =
      VectorType::get(IntegerType::getInt32Ty(VTy->getContext()), VF);
  auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
  Type *ElemTy = VTy->getScalarType();
  Type *BoolTy = MaskTy->getScalarType();
  // Spell the type out in full: the local reference below is also named TTI
  // and would otherwise shadow the `TTI` type alias.
  constexpr TargetTransformInfo::TargetCostKind CostKind =
      TargetTransformInfo::TCK_RecipThroughput;
  const TargetTransformInfo &TTI = Ctx.TTI;

  // Compares and selects go through the compare/select hook, and element
  // extracts through the vector-instruction hook; the arithmetic hook is kept
  // only for the `and`. The condition/result type is derived from the operand
  // type.
  auto SelectCost = [&](Type *ValTy) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::Select, ValTy,
                                  CmpInst::makeCmpResultType(ValTy),
                                  CmpInst::BAD_ICMP_PREDICATE, CostKind);
  };
  auto ICmpCost = [&](Type *OpTy, CmpInst::Predicate Pred) -> InstructionCost {
    return TTI.getCmpSelInstrCost(Instruction::ICmp, OpTy,
                                  CmpInst::makeCmpResultType(OpTy), Pred,
                                  CostKind);
  };

  // NOTE(review): the types costed below are unchanged from the previous
  // version. Several look inconsistent with the names of the values they
  // model: ActiveLaneIdxs and MaybeLastIdxEQZero are costed on i1, the first
  // LastIdx select on the data element type, and LastIdxGeZero on the i32
  // vector. Confirm against execute() and adjust.
  InstructionCost C = 0;

  // StepVector
  ArrayRef<Value *> Args;
  IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
  C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
  // NegOneSplat
  C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);

  // LastIdx (EVL) / MaybeLastIdx (non-EVL): both paths pay for one smax
  // reduction over the i32 index vector.
  C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
                                  CostKind);
  if (!usesEVL()) {
    // ActiveLaneIdxs
    C += SelectCost(BoolTy);
    // IsLaneZeroActive (index 0 taken from the value's name)
    C += TTI.getVectorInstrCost(Instruction::ExtractElement, MaskTy, CostKind,
                                /*Index=*/0);
    // MaybeLastIdxEQZero (predicate taken from the value's name)
    C += ICmpCost(BoolTy, CmpInst::ICMP_EQ);
    // And
    C += TTI.getArithmeticInstrCost(Instruction::And, BoolTy, CostKind);
    // LastIdx
    C += SelectCost(ElemTy);
  }

  // ExtractFromVec (index is not a compile-time constant, so none is passed)
  C += TTI.getVectorInstrCost(Instruction::ExtractElement, VTy, CostKind);
  // LastIdxGeZero (predicate taken from the value's name)
  C += ICmpCost(Int32VTy, CmpInst::ICMP_SGE);
  // ChooseFromVecOrInit
  C += SelectCost(ElemTy);
  return C;
}
2681+
25822682
void VPBranchOnMaskRecipe::execute(VPTransformState &State) {
25832683
assert(State.Lane && "Branch on Mask works only on single instance.");
25842684

0 commit comments

Comments
 (0)