Skip to content

Commit 073fd60

Browse files
[VPlan] Add cost model for CSA
1 parent b293b60 commit 073fd60

File tree

4 files changed

+236
-170
lines changed

4 files changed

+236
-170
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7298,9 +7298,17 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
72987298
/// not have corresponding recipes in \p Plan and are not marked to be ignored
72997299
/// in \p CostCtx. This means the VPlan contains simplification that the legacy
73007300
/// cost-model did not account for.
7301-
static bool planContainsAdditionalSimplifications(VPlan &Plan,
7302-
VPCostContext &CostCtx,
7303-
Loop *TheLoop) {
7301+
static bool
7302+
planContainsAdditionalSimplifications(VPlan &Plan, VPCostContext &CostCtx,
7303+
Loop *TheLoop,
7304+
LoopVectorizationLegality &Legal) {
7305+
// CSA cost is more complicated since there is significant overhead in the
7306+
// preheader and middle block. It also contains recipes that are not backed by
7307+
// underlying instructions in the original loop. This makes it difficult to
7308+
// model in the legacy cost model.
7309+
if (!Legal.getCSAs().empty())
7310+
return true;
7311+
73047312
// First collect all instructions for the recipes in Plan.
73057313
auto GetInstructionForCost = [](const VPRecipeBase *R) -> Instruction * {
73067314
if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
@@ -7408,7 +7416,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
74087416
precomputeCosts(BestPlan, BestFactor.Width, CostCtx);
74097417
assert((BestFactor.Width == LegacyVF.Width ||
74107418
planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),
7411-
CostCtx, OrigLoop)) &&
7419+
CostCtx, OrigLoop, *Legal)) &&
74127420
" VPlan cost model and legacy cost model disagreed");
74137421
assert((BestFactor.Width.isScalar() || BestFactor.ScalarCost > 0) &&
74147422
"when vectorizing, the scalar cost must be computed.");

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2502,6 +2502,9 @@ class VPCSAHeaderPHIRecipe final : public VPHeaderPHIRecipe {
25022502

25032503
void execute(VPTransformState &State) override;
25042504

2505+
InstructionCost computeCost(ElementCount VF,
2506+
VPCostContext &Ctx) const override;
2507+
25052508
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25062509
/// Print the recipe.
25072510
void print(raw_ostream &O, const Twine &Indent,
@@ -2533,6 +2536,9 @@ class VPCSADataUpdateRecipe final : public VPSingleDefRecipe {
25332536

25342537
void execute(VPTransformState &State) override;
25352538

2539+
InstructionCost computeCost(ElementCount VF,
2540+
VPCostContext &Ctx) const override;
2541+
25362542
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25372543
/// Print the recipe.
25382544
void print(raw_ostream &O, const Twine &Indent,
@@ -2579,6 +2585,9 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
25792585

25802586
void execute(VPTransformState &State) override;
25812587

2588+
InstructionCost computeCost(ElementCount VF,
2589+
VPCostContext &Ctx) const override;
2590+
25822591
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25832592
/// Print the recipe.
25842593
void print(raw_ostream &O, const Twine &Indent,
@@ -2589,7 +2598,7 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
25892598
VPValue *getVPMaskSel() const { return getOperand(1); }
25902599
VPValue *getVPDataSel() const { return getOperand(2); }
25912600
VPValue *getVPCSAVLSel() const { return getOperand(3); }
2592-
bool usesEVL() { return getNumOperands() == 4; }
2601+
bool usesEVL() const { return getNumOperands() == 4; }
25932602
};
25942603

25952604
/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,6 +2200,24 @@ void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
22002200
State.set(this, DataPhi, Part);
22012201
}
22022202

2203+
InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
2204+
VPCostContext &Ctx) const {
2205+
if (VF.isScalar())
2206+
return 0;
2207+
2208+
InstructionCost C = 0;
2209+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2210+
const TargetTransformInfo &TTI = Ctx.TTI;
2211+
2212+
// FIXME: These costs should be moved into VPInstruction::computeCost. We put
2213+
// them here for now since there is no VPInstruction::computeCost support.
2214+
// CSAInitMask
2215+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
2216+
// CSAInitData
2217+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
2218+
return C;
2219+
}
2220+
22032221
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
22042222
void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
22052223
VPSlotTracker &SlotTracker) const {
@@ -2228,6 +2246,34 @@ void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
22282246
}
22292247
}
22302248

2249+
InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
2250+
VPCostContext &Ctx) const {
2251+
if (VF.isScalar())
2252+
return 0;
2253+
2254+
InstructionCost C = 0;
2255+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2256+
auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
2257+
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2258+
const TargetTransformInfo &TTI = Ctx.TTI;
2259+
2260+
// Data Update
2261+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2262+
2263+
// FIXME: These costs should be moved into VPInstruction::computeCost. We put
2264+
// them here for now since they are related to updating the data and there is
2265+
// no VPInstruction::computeCost support at the moment. CSAInitMask AnyActive
2266+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2267+
// vp.reduce.or
2268+
C += TTI.getArithmeticReductionCost(Instruction::Or, VTy, std::nullopt,
2269+
CostKind);
2270+
// VPVLSel
2271+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2272+
// MaskUpdate
2273+
C += TTI.getArithmeticInstrCost(Instruction::Select, MaskTy, CostKind);
2274+
return C;
2275+
}
2276+
22312277
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
22322278
void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
22332279
VPSlotTracker &SlotTracker) const {
@@ -2288,6 +2334,60 @@ void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
22882334
State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
22892335
}
22902336

2337+
InstructionCost
2338+
VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
2339+
VPCostContext &Ctx) const {
2340+
if (VF.isScalar())
2341+
return 0;
2342+
2343+
InstructionCost C = 0;
2344+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2345+
auto *Int32VTy =
2346+
VectorType::get(IntegerType::getInt32Ty(VTy->getContext()), VF);
2347+
auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
2348+
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2349+
const TargetTransformInfo &TTI = Ctx.TTI;
2350+
2351+
// StepVector
2352+
ArrayRef<Value *> Args;
2353+
IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
2354+
C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
2355+
// NegOneSplat
2356+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);
2357+
// LastIdx
2358+
if (usesEVL()) {
2359+
C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
2360+
CostKind);
2361+
} else {
2362+
// ActiveLaneIdxs
2363+
C += TTI.getArithmeticInstrCost(Instruction::Select,
2364+
MaskTy->getScalarType(), CostKind);
2365+
// MaybeLastIdx
2366+
C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
2367+
CostKind);
2368+
// IsLaneZeroActive
2369+
C += TTI.getArithmeticInstrCost(Instruction::ExtractElement, MaskTy,
2370+
CostKind);
2371+
// MaybeLastIdxEQZero
2372+
C += TTI.getArithmeticInstrCost(Instruction::ICmp, MaskTy->getScalarType(),
2373+
CostKind);
2374+
// And
2375+
C += TTI.getArithmeticInstrCost(Instruction::And, MaskTy->getScalarType(),
2376+
CostKind);
2377+
// LastIdx
2378+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy->getScalarType(),
2379+
CostKind);
2380+
}
2381+
// ExtractFromVec
2382+
C += TTI.getArithmeticInstrCost(Instruction::ExtractElement, VTy, CostKind);
2383+
// LastIdxGeZero
2384+
C += TTI.getArithmeticInstrCost(Instruction::ICmp, Int32VTy, CostKind);
2385+
// ChooseFromVecOrInit
2386+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy->getScalarType(),
2387+
CostKind);
2388+
return C;
2389+
}
2390+
22912391
void VPBranchOnMaskRecipe::execute(VPTransformState &State) {
22922392
assert(State.Instance && "Branch on Mask works only on single instance.");
22932393

0 commit comments

Comments
 (0)