Skip to content

Commit 69435ea

Browse files
committed
[VPlan] Set branch weight metadata on middle term in VPlan (NFC)
Manage branch weights for the BranchOnCond in the middle block in VPlan. This requires updating VPInstruction to inherit from VPIRMetadata, which in general makes sense as there are a number of opcodes that could take metadata.
1 parent 4079ed3 commit 69435ea

File tree

3 files changed

+59
-45
lines changed

3 files changed

+59
-45
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7273,6 +7273,30 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72737273
BypassBlock, MainResumePhi->getIncomingValueForBlock(BypassBlock));
72747274
}
72757275

7276+
/// Add branch weight metadata (MD_prof) to the terminator of \p Plan's middle
7277+
/// block. Does nothing unless the latch terminator of \p OrigLoop has branch weights and the middle block's terminator is a VPInstruction (intended: a BranchOnCond recipe; NOTE(review): the opcode itself is not checked here).
7278+
static void addBranchWeigthToMiddleTerminator(VPlan &Plan, ElementCount VF,
7279+
Loop *OrigLoop) {
7280+
// Only annotate the middle block branch if the original loop's latch terminator carries branch weight metadata.
7281+
Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
7282+
if (!hasBranchWeightMD(*LatchTerm))
7283+
return;
7284+
7285+
VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
7286+
auto *MiddleTerm =
7287+
dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
7288+
if (!MiddleTerm)
7289+
return;
7290+
7291+
// Assume that `Count % VectorTripCount` is equally distributed, so the weights are {1, UF * (known minimum VF) - 1}.
7292+
unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
7293+
assert(TripCount > 0 && "trip count should not be zero");
7294+
MDBuilder MDB(LatchTerm->getContext());
7295+
MDNode *BranchWeights =
7296+
MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
7297+
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
7298+
}
7299+
72767300
DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
72777301
ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
72787302
InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
@@ -7295,11 +7319,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
72957319

72967320
VPlanTransforms::convertToConcreteRecipes(BestVPlan,
72977321
*Legal->getWidestInductionType());
7298-
// Retrieve and store the middle block before dissolving regions. Regions are
7299-
// dissolved after optimizing for VF and UF, which completely removes unneeded
7300-
// loop regions first.
7301-
VPBasicBlock *MiddleVPBB =
7302-
BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;
7322+
7323+
addBranchWeigthToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
73037324
VPlanTransforms::dissolveLoopRegions(BestVPlan);
73047325
// Perform the actual loop transformation.
73057326
VPTransformState State(&TTI, BestVF, LI, DT, ILV.AC, ILV.Builder, &BestVPlan,
@@ -7442,20 +7463,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
74427463

74437464
ILV.printDebugTracesAtEnd();
74447465

7445-
// 4. Adjust branch weight of the branch in the middle block.
7446-
if (HeaderVPBB) {
7447-
auto *MiddleTerm =
7448-
cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());
7449-
if (MiddleTerm->isConditional() &&
7450-
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7451-
// Assume that `Count % VectorTripCount` is equally distributed.
7452-
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7453-
assert(TripCount > 0 && "trip count should not be zero");
7454-
const uint32_t Weights[] = {1, TripCount - 1};
7455-
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7456-
}
7457-
}
7458-
74597466
return ExpandedSCEVs;
74607467
}
74617468

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -882,11 +882,39 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
882882
unsigned getUnrollPart(VPUser &U) const;
883883
};
884884

885+
/// Helper to manage IR metadata for recipes. It filters out metadata that
886+
/// cannot be propagated.
887+
class VPIRMetadata {
888+
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
889+
890+
public:
891+
VPIRMetadata() {}
892+
893+
/// Adds metadata that can be preserved from the original instruction
894+
/// \p I.
895+
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
896+
897+
/// Adds metadata that can be preserved from the original instruction
898+
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
899+
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
900+
901+
/// Copy constructor for cloning.
902+
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
903+
904+
/// Add all metadata to \p I.
905+
void applyMetadata(Instruction &I) const;
906+
/// Append the metadata node \p Node of kind \p Kind to the stored list as-is.
907+
void addMetadata(unsigned Kind, MDNode *Node) {
908+
Metadata.emplace_back(Kind, Node);
909+
}
910+
};
911+
885912
/// This is a concrete Recipe that models a single VPlan-level instruction.
886913
/// While as any Recipe it may generate a sequence of IR instructions when
887914
/// executed, these instructions would always form a single-def expression as
888915
/// the VPInstruction is also a single def-use vertex.
889916
class VPInstruction : public VPRecipeWithIRFlags,
917+
public VPIRMetadata,
890918
public VPUnrollPartAccessor<1> {
891919
friend class VPlanSlp;
892920

@@ -976,7 +1004,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
9761004
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
9771005
const Twine &Name = "")
9781006
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
979-
Opcode(Opcode), Name(Name.str()) {}
1007+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {}
9801008

9811009
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
9821010
const VPIRFlags &Flags, DebugLoc DL = {},
@@ -1268,29 +1296,6 @@ struct VPIRPhi : public VPIRInstruction, public VPPhiAccessors {
12681296
const VPRecipeBase *getAsRecipe() const override { return this; }
12691297
};
12701298

1271-
/// Helper to manage IR metadata for recipes. It filters out metadata that
1272-
/// cannot be propagated.
1273-
class VPIRMetadata {
1274-
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
1275-
1276-
public:
1277-
VPIRMetadata() {}
1278-
1279-
/// Adds metadata that can be preserved from the original instruction
1280-
/// \p I.
1281-
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
1282-
1283-
/// Adds metadata that can be preserved from the original instruction
1284-
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
1285-
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
1286-
1287-
/// Copy constructor for cloning.
1288-
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
1289-
1290-
/// Add all metadata to \p I.
1291-
void applyMetadata(Instruction &I) const;
1292-
};
1293-
12941299
/// VPWidenRecipe is a recipe for producing a widened instruction using the
12951300
/// opcode and operands of the recipe. This recipe covers most of the
12961301
/// traditional vectorization cases where each recipe transforms into a

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
410410
const VPIRFlags &Flags, DebugLoc DL,
411411
const Twine &Name)
412412
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
413-
Opcode(Opcode), Name(Name.str()) {
413+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {
414414
assert(flagsValidForOpcode(getOpcode()) &&
415415
"Set flags not supported for the provided opcode");
416416
}
@@ -591,7 +591,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
591591
}
592592
case VPInstruction::BranchOnCond: {
593593
Value *Cond = State.get(getOperand(0), VPLane(0));
594-
return createCondBranch(Cond, getParent(), State);
594+
auto *Br = createCondBranch(Cond, getParent(), State);
595+
applyMetadata(*Br);
596+
return Br;
595597
}
596598
case VPInstruction::BranchOnCount: {
597599
// First create the compare.

0 commit comments

Comments
 (0)