Skip to content

Commit 5b248ae

Browse files
committed
[VPlan] Set branch weight metadata on middle term in VPlan (NFC)
Manage branch weights for the BranchOnCond in the middle block in VPlan. This requires updating VPInstruction to inherit from VPIRMetadata, which in general makes sense as there are a number of opcodes that could take metadata.
1 parent c400fe2 commit 5b248ae

File tree

3 files changed

+59
-45
lines changed

3 files changed

+59
-45
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7279,6 +7279,30 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72797279
BypassBlock, MainResumePhi->getIncomingValueForBlock(BypassBlock));
72807280
}
72817281

7282+
/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
7283+
/// BranchOnCond recipe.
/// Does nothing unless the latch terminator of \p OrigLoop already carries
/// branch weight metadata. \p VF and the plan's UF determine the weights.
7284+
static void addBranchWeigthToMiddleTerminator(VPlan &Plan, ElementCount VF,
7285+
Loop *OrigLoop) {
7286+
// Adjust the branch weight of the branch in the middle block, but only if
// the original loop latch has branch weights to begin with.
7287+
Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
7288+
if (!hasBranchWeightMD(*LatchTerm))
7289+
return;
7290+
7291+
VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
7292+
auto *MiddleTerm =
7293+
dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
7294+
// Bail out if the middle block has no terminator or it is not a
// VPInstruction.
// NOTE(review): only the recipe class is checked here, not its opcode;
// presumably BranchOnCond is the only VPInstruction terminator that can
// reach this point -- confirm.
if (!MiddleTerm)
7295+
return;
7296+
7297+
// Assume that `Count % VectorTripCount` is equally distributed.
7298+
unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
7299+
assert(TripCount > 0 && "trip count should not be zero");
7300+
MDBuilder MDB(LatchTerm->getContext());
7301+
MDNode *BranchWeights =
7302+
MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
7303+
// Record the weights on the recipe as !prof metadata.
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
7304+
}
7305+
72827306
DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
72837307
ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
72847308
InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
@@ -7301,11 +7325,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
73017325

73027326
VPlanTransforms::convertToConcreteRecipes(BestVPlan,
73037327
*Legal->getWidestInductionType());
7304-
// Retrieve and store the middle block before dissolving regions. Regions are
7305-
// dissolved after optimizing for VF and UF, which completely removes unneeded
7306-
// loop regions first.
7307-
VPBasicBlock *MiddleVPBB =
7308-
BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;
7328+
7329+
addBranchWeigthToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
73097330
VPlanTransforms::dissolveLoopRegions(BestVPlan);
73107331
// Perform the actual loop transformation.
73117332
VPTransformState State(&TTI, BestVF, LI, DT, ILV.AC, ILV.Builder, &BestVPlan,
@@ -7454,20 +7475,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
74547475

74557476
ILV.printDebugTracesAtEnd();
74567477

7457-
// 4. Adjust branch weight of the branch in the middle block.
7458-
if (HeaderVPBB) {
7459-
auto *MiddleTerm =
7460-
cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());
7461-
if (MiddleTerm->isConditional() &&
7462-
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7463-
// Assume that `Count % VectorTripCount` is equally distributed.
7464-
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7465-
assert(TripCount > 0 && "trip count should not be zero");
7466-
const uint32_t Weights[] = {1, TripCount - 1};
7467-
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7468-
}
7469-
}
7470-
74717478
return ExpandedSCEVs;
74727479
}
74737480

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -882,11 +882,39 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
882882
unsigned getUnrollPart(VPUser &U) const;
883883
};
884884

885+
/// Helper to manage IR metadata for recipes. It filters out metadata that
886+
/// cannot be propagated.
887+
class VPIRMetadata {
888+
// Stored (metadata kind, node) pairs.
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
889+
890+
public:
891+
/// Creates an instance with no metadata.
VPIRMetadata() {}
892+
893+
/// Adds metadata that can be preserved from the original instruction
894+
/// \p I.
895+
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
896+
897+
/// Adds metadata that can be preserved from the original instruction
898+
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
899+
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
900+
901+
/// Copy constructor for cloning.
902+
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
903+
904+
/// Add all metadata to \p I.
905+
void applyMetadata(Instruction &I) const;
906+
907+
/// Add metadata \p Node with kind \p Kind. The pair is appended as-is; no
/// filtering is applied by this method.
void addMetadata(unsigned Kind, MDNode *Node) {
908+
Metadata.emplace_back(Kind, Node);
909+
}
910+
};
911+
885912
/// This is a concrete Recipe that models a single VPlan-level instruction.
886913
/// While as any Recipe it may generate a sequence of IR instructions when
887914
/// executed, these instructions would always form a single-def expression as
888915
/// the VPInstruction is also a single def-use vertex.
889916
class VPInstruction : public VPRecipeWithIRFlags,
917+
public VPIRMetadata,
890918
public VPUnrollPartAccessor<1> {
891919
friend class VPlanSlp;
892920

@@ -972,7 +1000,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
9721000
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
9731001
const Twine &Name = "")
9741002
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
975-
Opcode(Opcode), Name(Name.str()) {}
1003+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {}
9761004

9771005
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
9781006
const VPIRFlags &Flags, DebugLoc DL = {},
@@ -1264,29 +1292,6 @@ struct VPIRPhi : public VPIRInstruction, public VPPhiAccessors {
12641292
const VPRecipeBase *getAsRecipe() const override { return this; }
12651293
};
12661294

1267-
/// Helper to manage IR metadata for recipes. It filters out metadata that
1268-
/// cannot be propagated.
1269-
class VPIRMetadata {
1270-
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
1271-
1272-
public:
1273-
VPIRMetadata() {}
1274-
1275-
/// Adds metatadata that can be preserved from the original instruction
1276-
/// \p I.
1277-
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
1278-
1279-
/// Adds metatadata that can be preserved from the original instruction
1280-
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
1281-
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
1282-
1283-
/// Copy constructor for cloning.
1284-
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
1285-
1286-
/// Add all metadata to \p I.
1287-
void applyMetadata(Instruction &I) const;
1288-
};
1289-
12901295
/// VPWidenRecipe is a recipe for producing a widened instruction using the
12911296
/// opcode and operands of the recipe. This recipe covers most of the
12921297
/// traditional vectorization cases where each recipe transforms into a

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
410410
const VPIRFlags &Flags, DebugLoc DL,
411411
const Twine &Name)
412412
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
413-
Opcode(Opcode), Name(Name.str()) {
413+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {
414414
assert(flagsValidForOpcode(getOpcode()) &&
415415
"Set flags not supported for the provided opcode");
416416
}
@@ -591,7 +591,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
591591
}
592592
case VPInstruction::BranchOnCond: {
593593
Value *Cond = State.get(getOperand(0), VPLane(0));
594-
return createCondBranch(Cond, getParent(), State);
594+
auto *Br = createCondBranch(Cond, getParent(), State);
595+
applyMetadata(*Br);
596+
return Br;
595597
}
596598
case VPInstruction::BranchOnCount: {
597599
// First create the compare.

0 commit comments

Comments
 (0)