Skip to content

[VPlan] Set branch weight metadata on middle term in VPlan (NFC) #143035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7273,6 +7273,33 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
BypassBlock, MainResumePhi->getIncomingValueForBlock(BypassBlock));
}

/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
/// BranchOnCond recipe.
static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
Loop *OrigLoop) {
// 4. Adjust branch weight of the branch in the middle block.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop "4", along with the entire line - being contained in the documentation of the function above?

Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
if (!hasBranchWeightMD(*LatchTerm))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that weights are added to the new middle terminator only if the original latch terminator has weights, although the weights themselves are independent. Suffice to indicate if VPlan should introduce branch weights by noting if the original loop has any?

return;

VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
auto *MiddleTerm =
dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code that was removed look like this:

    auto *MiddleTerm =
        cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());

whereas the code you've added here suggests that there may not be a terminator at all, or it may not be a VPInstruction. It may be right, but it's not obvious to me at least how this is NFC? I'm a bit worried that we may not be adding branch weights any more for some loops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In VPlan, blocks with a single successor won't have a terminator recipe. Only blocks with multiple successors have (conditional) terminator recipes, which must be BranchOnCond/BranchOnCount VPInstructions (this should be checked in the verifier). I think the new and old code should be equivalent, as the removed code also checks MiddleTerm->isConditional() to skip unconditional branches.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK thanks for explaining. Might be good to leave a comment here for the reader explaining that MiddleTerm is essentially only non-null for conditional terminators where branch weights would apply.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks

// Only add branch metadata if there is a (conditional) terminator.
if (!MiddleTerm)
return;

assert(MiddleTerm->getOpcode() == VPInstruction::BranchOnCond &&
"must have a BranchOnCond");
// Assume that `Count % VectorTripCount` is equally distributed.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Count is aka original TripCount, and VectorTripCount should be VectorStep: the branch terminating middle block has 1:VectorStep chance of skipping scalar epilog, i.e., when Count or original TripCount is divisible by VectorStep, assuming modulo remainder Count % VectorStep is uniformly distributed.

unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not saying you should fix this here, but I just realised that this doesn't really use the estimated value of vscale for a given target when the VF is scalable. Probably something we should improve.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes would probably be good to improve handling for vscale separately.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TripCount should be called VectorStep.

assert(TripCount > 0 && "trip count should not be zero");
MDBuilder MDB(LatchTerm->getContext());
MDNode *BranchWeights =
MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative to introducing addMetadata() is feeding these weights when constructing terminator of middle block, but that requires knowing VF and UF.

}

DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
Expand All @@ -7295,11 +7322,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(

VPlanTransforms::convertToConcreteRecipes(BestVPlan,
*Legal->getWidestInductionType());
// Retrieve and store the middle block before dissolving regions. Regions are
// dissolved after optimizing for VF and UF, which completely removes unneeded
// loop regions first.
Comment on lines -7298 to -7300
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the comment explains the positioning of dissolveLoopRegions() below, so should be retained?

  // Regions are dissolved after optimizing for VF and UF, which completely removes unneeded
  // loop regions first.

VPBasicBlock *MiddleVPBB =
BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;

addBranchWeightToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be a VPlanTransform?

VPlanTransforms::dissolveLoopRegions(BestVPlan);
// Perform the actual loop transformation.
VPTransformState State(&TTI, BestVF, LI, DT, ILV.AC, ILV.Builder, &BestVPlan,
Expand Down Expand Up @@ -7442,20 +7466,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(

ILV.printDebugTracesAtEnd();

// 4. Adjust branch weight of the branch in the middle block.
if (HeaderVPBB) {
auto *MiddleTerm =
cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());
if (MiddleTerm->isConditional() &&
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
// Assume that `Count % VectorTripCount` is equally distributed.
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
assert(TripCount > 0 && "trip count should not be zero");
const uint32_t Weights[] = {1, TripCount - 1};
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
}
}

return ExpandedSCEVs;
}

Expand Down
53 changes: 29 additions & 24 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -882,11 +882,39 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
unsigned getUnrollPart(VPUser &U) const;
};

/// Helper to manage IR metadata for recipes. It filters out metadata that
/// cannot be propagated.
class VPIRMetadata {
SmallVector<std::pair<unsigned, MDNode *>> Metadata;

public:
VPIRMetadata() {}

/// Adds metatadata that can be preserved from the original instruction
/// \p I.
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }

/// Adds metatadata that can be preserved from the original instruction
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
VPIRMetadata(Instruction &I, LoopVersioning *LVer);

/// Copy constructor for cloning.
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}

/// Add all metadata to \p I.
void applyMetadata(Instruction &I) const;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The newly introduced addMetadata() method deserves documenting, albeit trivial.

void addMetadata(unsigned Kind, MDNode *Node) {
Metadata.emplace_back(Kind, Node);
}
};

/// This is a concrete Recipe that models a single VPlan-level instruction.
/// While as any Recipe it may generate a sequence of IR instructions when
/// executed, these instructions would always form a single-def expression as
/// the VPInstruction is also a single def-use vertex.
class VPInstruction : public VPRecipeWithIRFlags,
public VPIRMetadata,
public VPUnrollPartAccessor<1> {
friend class VPlanSlp;

Expand Down Expand Up @@ -976,7 +1004,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
const Twine &Name = "")
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
Opcode(Opcode), Name(Name.str()) {}
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {}

VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
const VPIRFlags &Flags, DebugLoc DL = {},
Expand Down Expand Up @@ -1268,29 +1296,6 @@ struct VPIRPhi : public VPIRInstruction, public VPPhiAccessors {
const VPRecipeBase *getAsRecipe() const override { return this; }
};

/// Helper to manage IR metadata for recipes. It filters out metadata that
/// cannot be propagated.
class VPIRMetadata {
SmallVector<std::pair<unsigned, MDNode *>> Metadata;

public:
VPIRMetadata() {}

/// Adds metatadata that can be preserved from the original instruction
/// \p I.
VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }

/// Adds metatadata that can be preserved from the original instruction
/// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
VPIRMetadata(Instruction &I, LoopVersioning *LVer);

/// Copy constructor for cloning.
VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}

/// Add all metadata to \p I.
void applyMetadata(Instruction &I) const;
};

/// VPWidenRecipe is a recipe for producing a widened instruction using the
/// opcode and operands of the recipe. This recipe covers most of the
/// traditional vectorization cases where each recipe transforms into a
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
const VPIRFlags &Flags, DebugLoc DL,
const Twine &Name)
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
Opcode(Opcode), Name(Name.str()) {
VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {
assert(flagsValidForOpcode(getOpcode()) &&
"Set flags not supported for the provided opcode");
}
Expand Down Expand Up @@ -591,7 +591,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
}
case VPInstruction::BranchOnCond: {
Value *Cond = State.get(getOperand(0), VPLane(0));
return createCondBranch(Cond, getParent(), State);
auto *Br = createCondBranch(Cond, getParent(), State);
applyMetadata(*Br);
return Br;
}
case VPInstruction::BranchOnCount: {
// First create the compare.
Expand Down
Loading