Skip to content

Commit c69b07a

Browse files
committed
[VPlan] Set branch weight metadata on middle term in VPlan (NFC)
Manage branch weights for the BranchOnCond in the middle block in VPlan. This requires updating VPInstruction to inherit from VPIRMetadata, which in general makes sense as there are a number of opcodes that could take metadata.
1 parent 3cb104e commit c69b07a

File tree

3 files changed

+59
-45
lines changed

3 files changed

+59
-45
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7282,6 +7282,30 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72827282
BypassBlock, MainResumePhi->getIncomingValueForBlock(BypassBlock));
72837283
}
72847284

7285+
/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
7286+
/// BranchOnCond recipe.
7287+
static void addBranchWeigthToMiddleTerminator(VPlan &Plan, ElementCount VF,
7288+
Loop *OrigLoop) {
7289+
// 4. Adjust branch weight of the branch in the middle block.
7290+
Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
7291+
if (!hasBranchWeightMD(*LatchTerm))
7292+
return;
7293+
7294+
VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
7295+
auto *MiddleTerm =
7296+
dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
7297+
if (!MiddleTerm)
7298+
return;
7299+
7300+
// Assume that `Count % VectorTripCount` is equally distributed.
7301+
unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
7302+
assert(TripCount > 0 && "trip count should not be zero");
7303+
MDBuilder MDB(LatchTerm->getContext());
7304+
MDNode *BranchWeights =
7305+
MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
7306+
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
7307+
}
7308+
72857309
DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
72867310
ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
72877311
InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
@@ -7304,11 +7328,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
73047328

73057329
VPlanTransforms::convertToConcreteRecipes(BestVPlan,
73067330
*Legal->getWidestInductionType());
7307-
// Retrieve and store the middle block before dissolving regions. Regions are
7308-
// dissolved after optimizing for VF and UF, which completely removes unneeded
7309-
// loop regions first.
7310-
VPBasicBlock *MiddleVPBB =
7311-
BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;
7331+
7332+
addBranchWeigthToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
73127333
VPlanTransforms::dissolveLoopRegions(BestVPlan);
73137334
// Perform the actual loop transformation.
73147335
VPTransformState State(&TTI, BestVF, LI, DT, ILV.AC, ILV.Builder, &BestVPlan,
@@ -7451,20 +7472,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
74517472

74527473
ILV.printDebugTracesAtEnd();
74537474

7454-
// 4. Adjust branch weight of the branch in the middle block.
7455-
if (HeaderVPBB) {
7456-
auto *MiddleTerm =
7457-
cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());
7458-
if (MiddleTerm->isConditional() &&
7459-
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7460-
// Assume that `Count % VectorTripCount` is equally distributed.
7461-
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7462-
assert(TripCount > 0 && "trip count should not be zero");
7463-
const uint32_t Weights[] = {1, TripCount - 1};
7464-
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7465-
}
7466-
}
7467-
74687475
return ExpandedSCEVs;
74697476
}
74707477

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -882,11 +882,39 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
882882
unsigned getUnrollPart(VPUser &U) const;
883883
};
884884

885+
/// Helper to manage IR metadata for recipes. It filters out metadata that
886+
/// cannot be propagated.
887+
class VPIRMetadata {
888+
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
889+
890+
public:
891+
VPIRMetadata() {}
892+
893+
/// Adds metatadata that can be preserved from the original instruction
894+
/// \p I.
895+
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
896+
897+
/// Adds metatadata that can be preserved from the original instruction
898+
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
899+
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
900+
901+
/// Copy constructor for cloning.
902+
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
903+
904+
/// Add all metadata to \p I.
905+
void applyMetadata(Instruction &I) const;
906+
907+
void addMetadata(unsigned Kind, MDNode *Node) {
908+
Metadata.emplace_back(Kind, Node);
909+
}
910+
};
911+
885912
/// This is a concrete Recipe that models a single VPlan-level instruction.
886913
/// While as any Recipe it may generate a sequence of IR instructions when
887914
/// executed, these instructions would always form a single-def expression as
888915
/// the VPInstruction is also a single def-use vertex.
889916
class VPInstruction : public VPRecipeWithIRFlags,
917+
public VPIRMetadata,
890918
public VPUnrollPartAccessor<1> {
891919
friend class VPlanSlp;
892920

@@ -976,7 +1004,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
9761004
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
9771005
const Twine &Name = "")
9781006
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
979-
Opcode(Opcode), Name(Name.str()) {}
1007+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {}
9801008

9811009
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
9821010
const VPIRFlags &Flags, DebugLoc DL = {},
@@ -1268,29 +1296,6 @@ struct VPIRPhi : public VPIRInstruction, public VPPhiAccessors {
12681296
const VPRecipeBase *getAsRecipe() const override { return this; }
12691297
};
12701298

1271-
/// Helper to manage IR metadata for recipes. It filters out metadata that
1272-
/// cannot be propagated.
1273-
class VPIRMetadata {
1274-
SmallVector<std::pair<unsigned, MDNode *>> Metadata;
1275-
1276-
public:
1277-
VPIRMetadata() {}
1278-
1279-
/// Adds metatadata that can be preserved from the original instruction
1280-
/// \p I.
1281-
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
1282-
1283-
/// Adds metatadata that can be preserved from the original instruction
1284-
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
1285-
VPIRMetadata(Instruction &I, LoopVersioning *LVer);
1286-
1287-
/// Copy constructor for cloning.
1288-
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
1289-
1290-
/// Add all metadata to \p I.
1291-
void applyMetadata(Instruction &I) const;
1292-
};
1293-
12941299
/// VPWidenRecipe is a recipe for producing a widened instruction using the
12951300
/// opcode and operands of the recipe. This recipe covers most of the
12961301
/// traditional vectorization cases where each recipe transforms into a

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
410410
const VPIRFlags &Flags, DebugLoc DL,
411411
const Twine &Name)
412412
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
413-
Opcode(Opcode), Name(Name.str()) {
413+
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {
414414
assert(flagsValidForOpcode(getOpcode()) &&
415415
"Set flags not supported for the provided opcode");
416416
}
@@ -591,7 +591,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
591591
}
592592
case VPInstruction::BranchOnCond: {
593593
Value *Cond = State.get(getOperand(0), VPLane(0));
594-
return createCondBranch(Cond, getParent(), State);
594+
auto *Br = createCondBranch(Cond, getParent(), State);
595+
applyMetadata(*Br);
596+
return Br;
595597
}
596598
case VPInstruction::BranchOnCount: {
597599
// First create the compare.

0 commit comments

Comments
 (0)