[VPlan] Set branch weight metadata on middle term in VPlan (NFC) #143035

Merged: fhahn merged 4 commits into llvm:main from the vplan-set-branch-weight-middle-term branch on Jun 12, 2025

Conversation

Contributor

@fhahn fhahn commented Jun 5, 2025

Manage branch weights for the BranchOnCond in the middle block in VPlan. This requires updating VPInstruction to inherit from VPIRMetadata, which in general makes sense as there are a number of opcodes that could take metadata.

There are other branches (part of the skeleton) that also need branch weights adding.

Member

llvmbot commented Jun 5, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)


Full diff: https://github.com/llvm/llvm-project/pull/143035.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+26-19)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+29-24)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+4-2)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fc8ebebcf21b7..be77d0a579249 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7279,6 +7279,30 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
       BypassBlock, MainResumePhi->getIncomingValueForBlock(BypassBlock));
 }
 
+/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
+/// BranchOnCond recipe.
+static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
+                                              Loop *OrigLoop) {
+  // 4. Adjust branch weight of the branch in the middle block.
+  Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
+  if (!hasBranchWeightMD(*LatchTerm))
+    return;
+
+  VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
+  auto *MiddleTerm =
+      dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
+  if (!MiddleTerm)
+    return;
+
+  // Assume that `Count % VectorTripCount` is equally distributed.
+  unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
+  assert(TripCount > 0 && "trip count should not be zero");
+  MDBuilder MDB(LatchTerm->getContext());
+  MDNode *BranchWeights =
+      MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
+  MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
     ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan,
     InnerLoopVectorizer &ILV, DominatorTree *DT, bool VectorizingEpilogue) {
@@ -7301,11 +7325,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   VPlanTransforms::convertToConcreteRecipes(BestVPlan,
                                             *Legal->getWidestInductionType());
-  // Retrieve and store the middle block before dissolving regions. Regions are
-  // dissolved after optimizing for VF and UF, which completely removes unneeded
-  // loop regions first.
-  VPBasicBlock *MiddleVPBB =
-      BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;
+
+  addBranchWeightToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
   VPlanTransforms::dissolveLoopRegions(BestVPlan);
   // Perform the actual loop transformation.
   VPTransformState State(&TTI, BestVF, LI, DT, ILV.AC, ILV.Builder, &BestVPlan,
@@ -7454,20 +7475,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
 
   ILV.printDebugTracesAtEnd();
 
-  // 4. Adjust branch weight of the branch in the middle block.
-  if (HeaderVPBB) {
-    auto *MiddleTerm =
-        cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());
-    if (MiddleTerm->isConditional() &&
-        hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
-      // Assume that `Count % VectorTripCount` is equally distributed.
-      unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
-      assert(TripCount > 0 && "trip count should not be zero");
-      const uint32_t Weights[] = {1, TripCount - 1};
-      setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
-    }
-  }
-
   return ExpandedSCEVs;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 273df55188c16..14a9521658fa3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -882,11 +882,39 @@ template <unsigned PartOpIdx> class VPUnrollPartAccessor {
   unsigned getUnrollPart(VPUser &U) const;
 };
 
+/// Helper to manage IR metadata for recipes. It filters out metadata that
+/// cannot be propagated.
+class VPIRMetadata {
+  SmallVector<std::pair<unsigned, MDNode *>> Metadata;
+
+public:
+  VPIRMetadata() {}
+
+  /// Adds metadata that can be preserved from the original instruction
+  /// \p I.
+  VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
+
+  /// Adds metadata that can be preserved from the original instruction
+  /// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
+  VPIRMetadata(Instruction &I, LoopVersioning *LVer);
+
+  /// Copy constructor for cloning.
+  VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
+
+  /// Add all metadata to \p I.
+  void applyMetadata(Instruction &I) const;
+
+  void addMetadata(unsigned Kind, MDNode *Node) {
+    Metadata.emplace_back(Kind, Node);
+  }
+};
+
 /// This is a concrete Recipe that models a single VPlan-level instruction.
 /// While as any Recipe it may generate a sequence of IR instructions when
 /// executed, these instructions would always form a single-def expression as
 /// the VPInstruction is also a single def-use vertex.
 class VPInstruction : public VPRecipeWithIRFlags,
+                      public VPIRMetadata,
                       public VPUnrollPartAccessor<1> {
   friend class VPlanSlp;
 
@@ -972,7 +1000,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
   VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
                 const Twine &Name = "")
       : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
-        Opcode(Opcode), Name(Name.str()) {}
+        VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {}
 
   VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
                 const VPIRFlags &Flags, DebugLoc DL = {},
@@ -1264,29 +1292,6 @@ struct VPIRPhi : public VPIRInstruction, public VPPhiAccessors {
   const VPRecipeBase *getAsRecipe() const override { return this; }
 };
 
-/// Helper to manage IR metadata for recipes. It filters out metadata that
-/// cannot be propagated.
-class VPIRMetadata {
-  SmallVector<std::pair<unsigned, MDNode *>> Metadata;
-
-public:
-  VPIRMetadata() {}
-
-  /// Adds metatadata that can be preserved from the original instruction
-  /// \p I.
-  VPIRMetadata(Instruction &I) { getMetadataToPropagate(&I, Metadata); }
-
-  /// Adds metatadata that can be preserved from the original instruction
-  /// \p I and noalias metadata guaranteed by runtime checks using \p LVer.
-  VPIRMetadata(Instruction &I, LoopVersioning *LVer);
-
-  /// Copy constructor for cloning.
-  VPIRMetadata(const VPIRMetadata &Other) : Metadata(Other.Metadata) {}
-
-  /// Add all metadata to \p I.
-  void applyMetadata(Instruction &I) const;
-};
-
 /// VPWidenRecipe is a recipe for producing a widened instruction using the
 /// opcode and operands of the recipe. This recipe covers most of the
 /// traditional vectorization cases where each recipe transforms into a
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2aa5dd1b48c00..6e4821c0e5387 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -410,7 +410,7 @@ VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
                              const VPIRFlags &Flags, DebugLoc DL,
                              const Twine &Name)
     : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
-      Opcode(Opcode), Name(Name.str()) {
+      VPIRMetadata(), Opcode(Opcode), Name(Name.str()) {
   assert(flagsValidForOpcode(getOpcode()) &&
          "Set flags not supported for the provided opcode");
 }
@@ -591,7 +591,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
   }
   case VPInstruction::BranchOnCond: {
     Value *Cond = State.get(getOperand(0), VPLane(0));
-    return createCondBranch(Cond, getParent(), State);
+    auto *Br = createCondBranch(Cond, getParent(), State);
+    applyMetadata(*Br);
+    return Br;
   }
   case VPInstruction::BranchOnCount: {
     // First create the compare.

@fhahn fhahn force-pushed the vplan-set-branch-weight-middle-term branch from 2e1d91a to 5b248ae Compare June 9, 2025 12:52

    VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
    auto *MiddleTerm =
        dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
Contributor

The code that was removed looks like this:

    auto *MiddleTerm =
        cast<BranchInst>(State.CFG.VPBB2IRBB[MiddleVPBB]->getTerminator());

whereas the code you've added here suggests that there may not be a terminator at all, or that it may not be a VPInstruction. It may be right, but it's not obvious, to me at least, how this is NFC. I'm a bit worried that we may no longer be adding branch weights for some loops.

Contributor Author

In VPlan, blocks with a single successor won't have a terminator recipe. Only blocks with multiple successors have (conditional) terminator recipes, which must be BranchOnCond/BranchOnCount VPInstructions (this should be checked in the verifier). I think the new and old code should be equivalent, as the removed code also checks MiddleTerm->isConditional() to skip unconditional branches.
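
As an illustration of the equivalence argued above (not code from the PR), here is a minimal, self-contained sketch. `Block`, `Recipe`, `Opcode` and `wouldAttachWeights` are hypothetical stand-ins, not the real VPlan classes; the only assumption carried over from the discussion is that a terminator recipe exists only when a block has more than one successor.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for VPlan concepts; names are illustrative only.
enum class Opcode { BranchOnCond, BranchOnCount, Other };

struct Recipe {
  Opcode Op;
};

struct Block {
  std::vector<Recipe> Recipes;
  unsigned NumSuccessors = 1;

  // Models the convention described above: only blocks with more than one
  // successor end in a (conditional) terminator recipe.
  const Recipe *getTerminator() const {
    if (NumSuccessors < 2 || Recipes.empty())
      return nullptr;
    const Recipe &Last = Recipes.back();
    bool IsBranch =
        Last.Op == Opcode::BranchOnCond || Last.Op == Opcode::BranchOnCount;
    return IsBranch ? &Last : nullptr;
  }
};

// Returns true if branch weights would be attached to the middle block.
// A null terminator plays the role of the old `!isConditional()` check.
bool wouldAttachWeights(const Block &Middle, bool LatchHasWeights) {
  if (!LatchHasWeights)
    return false;
  return Middle.getTerminator() != nullptr;
}
```

In this toy model, a block with a single successor never receives weights, which matches the unconditional-branch case that the removed code skipped.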

Contributor

OK thanks for explaining. Might be good to leave a comment here for the reader explaining that MiddleTerm is essentially only non-null for conditional terminators where branch weights would apply.

Contributor Author

Added, thanks

      return;

    // Assume that `Count % VectorTripCount` is equally distributed.
    unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
Contributor

Not saying you should fix this here, but I just realised that this doesn't really use the estimated value of vscale for a given target when the VF is scalable. Probably something we should improve.

Contributor Author

Yes, it would probably be good to improve the handling of vscale separately.
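
To make the vscale point concrete, the sketch below contrasts the step used in this PR (known minimum lane count times UF) with one possible refinement that multiplies in an estimated vscale. `ToyElementCount`, `knownMinStep` and `estimatedStep` are hypothetical names for illustration, and how an estimate would actually be obtained for a target is left open here.

```cpp
#include <cassert>

// Hypothetical simplification of an element count: MinValue lanes, multiplied
// by vscale at run time when Scalable is true.
struct ToyElementCount {
  unsigned MinValue;
  bool Scalable;
};

// Step as computed in the PR: uses only the known minimum lane count.
unsigned knownMinStep(ToyElementCount VF, unsigned UF) {
  return UF * VF.MinValue;
}

// Possible refinement (an assumption, not part of the PR): scale by an
// estimated vscale for scalable VFs. Fixed-width VFs are left unchanged.
unsigned estimatedStep(ToyElementCount VF, unsigned UF,
                       unsigned EstimatedVScale) {
  unsigned Step = UF * VF.MinValue;
  if (VF.Scalable)
    Step *= EstimatedVScale;
  return Step;
}
```

With a scalable VF of 4 lanes, UF = 2 and an estimated vscale of 4, the two functions return 8 and 32 respectively, so the weights derived from the step would differ accordingly.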

@fhahn fhahn force-pushed the vplan-set-branch-weight-middle-term branch from 5b248ae to d5da6cc Compare June 10, 2025 15:13
Contributor

@david-arm david-arm left a comment

LGTM!

fhahn added 3 commits June 12, 2025 09:38
Manage branch weights for the BranchOnCond in the middle block in VPlan.
This requires updating VPInstruction to inherit from VPIRMetadata, which
in general makes sense as there are a number of opcodes that could take
metadata.
@fhahn fhahn force-pushed the vplan-set-branch-weight-middle-term branch from d5da6cc to 91c63c6 Compare June 12, 2025 08:44
@fhahn fhahn merged commit db8d34d into llvm:main Jun 12, 2025
5 of 7 checks passed
@fhahn fhahn deleted the vplan-set-branch-weight-middle-term branch June 12, 2025 09:04
Collaborator

@ayalz ayalz left a comment

Post-commit review.

Comment on lines -7298 to -7300
    // Retrieve and store the middle block before dissolving regions. Regions are
    // dissolved after optimizing for VF and UF, which completely removes unneeded
    // loop regions first.
Collaborator

This part of the comment explains the positioning of dissolveLoopRegions() below, so should be retained?

  // Regions are dissolved after optimizing for VF and UF, which completely removes unneeded
  // loop regions first.


    /// Add all metadata to \p I.
    void applyMetadata(Instruction &I) const;

Collaborator

The newly introduced addMetadata() method deserves documentation, however trivial.
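
For illustration only, a self-contained sketch of the record-then-apply pattern with a documented `addMetadata()`. `ToyInstruction`, `ToyMetadata` and `ToyBranchRecipe` are hypothetical stand-ins that use strings in place of `MDNode *`; the doc-comment wording is a suggestion, not the text that landed.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an IR instruction carrying (kind, node) metadata.
struct ToyInstruction {
  std::vector<std::pair<unsigned, std::string>> Attached;
};

class ToyMetadata {
  std::vector<std::pair<unsigned, std::string>> Metadata;

public:
  /// Record metadata \p Node of kind \p Kind. Nothing is attached yet; the
  /// entry is emitted later by applyMetadata().
  void addMetadata(unsigned Kind, std::string Node) {
    Metadata.emplace_back(Kind, std::move(Node));
  }

  /// Add all recorded metadata to \p I.
  void applyMetadata(ToyInstruction &I) const {
    for (const auto &KindAndNode : Metadata)
      I.Attached.push_back(KindAndNode);
  }
};

// A recipe-like class picks up the metadata API through inheritance, in the
// spirit of VPInstruction inheriting from VPIRMetadata.
struct ToyBranchRecipe : ToyMetadata {
  ToyInstruction generate() const {
    ToyInstruction Br;
    applyMetadata(Br); // Attach whatever was recorded before code generation.
    return Br;
  }
};
```

The point of the split is that weights can be decided while the plan is still being transformed and only materialize when the branch is generated.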

    VPBasicBlock *MiddleVPBB =
        BestVPlan.getVectorLoopRegion() ? BestVPlan.getMiddleBlock() : nullptr;

    addBranchWeightToMiddleTerminator(BestVPlan, BestVF, OrigLoop);
Collaborator

Could this be a VPlanTransform?

    /// BranchOnCond recipe.
    static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
                                                  Loop *OrigLoop) {
      // 4. Adjust branch weight of the branch in the middle block.
Collaborator

Drop the "4.", or indeed the entire line, since this is already covered by the documentation of the function above?

                                                  Loop *OrigLoop) {
      // 4. Adjust branch weight of the branch in the middle block.
      Instruction *LatchTerm = OrigLoop->getLoopLatch()->getTerminator();
      if (!hasBranchWeightMD(*LatchTerm))
Collaborator

Note that weights are added to the new middle terminator only if the original latch terminator has weights, although the weights themselves are independent of the original ones. Would it suffice to indicate whether VPlan should introduce branch weights by noting whether the original loop has any?

    assert(MiddleTerm->getOpcode() == VPInstruction::BranchOnCond &&
           "must have a BranchOnCond");
    // Assume that `Count % VectorTripCount` is equally distributed.
    unsigned TripCount = Plan.getUF() * VF.getKnownMinValue();
Collaborator

TripCount should be called VectorStep.


    assert(MiddleTerm->getOpcode() == VPInstruction::BranchOnCond &&
           "must have a BranchOnCond");
    // Assume that `Count % VectorTripCount` is equally distributed.
Collaborator

Count is the original TripCount, and VectorTripCount should be VectorStep: the branch terminating the middle block has a 1-in-VectorStep chance of skipping the scalar epilogue, i.e., when Count (the original TripCount) is divisible by VectorStep, assuming the remainder Count % VectorStep is uniformly distributed.
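
The reasoning above can be checked with a small, self-contained sketch. `middleBranchWeights` and `countExactMultiples` are hypothetical helpers written for this note; the first mirrors the `{1, Step - 1}` weights from the patch, the second counts how often the remainder is zero over a range of trip counts, under the stated uniform-remainder assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Weights for the middle-block branch, in the order used by the patch:
// {remainder is zero (scalar epilogue skipped), remainder is non-zero}.
std::pair<uint32_t, uint32_t> middleBranchWeights(uint32_t VectorStep) {
  assert(VectorStep > 0 && "vector step should not be zero");
  return std::pair<uint32_t, uint32_t>(1, VectorStep - 1);
}

// Counts, over original trip counts 1..MaxCount, how often Count is an exact
// multiple of VectorStep, i.e. how often the scalar epilogue has no work.
uint32_t countExactMultiples(uint32_t VectorStep, uint32_t MaxCount) {
  uint32_t Hits = 0;
  for (uint32_t Count = 1; Count <= MaxCount; ++Count)
    if (Count % VectorStep == 0)
      ++Hits;
  return Hits;
}
```

For a step of 8 and trip counts 1 through 800, exactly 100 counts are multiples of 8 and 700 are not, which is the same 1:7 ratio that the weights encode.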

    MDBuilder MDB(LatchTerm->getContext());
    MDNode *BranchWeights =
        MDB.createBranchWeights({1, TripCount - 1}, /*IsExpected=*/false);
    MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
Collaborator

An alternative to introducing addMetadata() is to feed these weights in when constructing the terminator of the middle block, but that requires knowing VF and UF.

Collaborator

zmodem commented Jun 12, 2025

We're hitting assertion failures after this change:

      llvm/lib/Transforms/Vectorize/VPlan.h:4021:
      llvm::VPBasicBlock* llvm::VPlan::getMiddleBlock():
      Assertion `LoopRegion && "cannot call the function after vector loop region has been removed"' failed.

See https://crbug.com/424377400#comment3 for a reproducer.

I'll revert while it gets fixed.

zmodem added a commit that referenced this pull request Jun 12, 2025
…FC) (#143035)"

This caused assertion failures:

  llvm/lib/Transforms/Vectorize/VPlan.h:4021:
  llvm::VPBasicBlock* llvm::VPlan::getMiddleBlock():
  Assertion `LoopRegion && "cannot call the function after vector loop region has been removed"' failed.

See comment on the PR.

> Manage branch weights for the BranchOnCond in the middle block in VPlan.
> This requires updating VPInstruction to inherit from VPIRMetadata, which
> in general makes sense as there are a number of opcodes that could take
> metadata.
>
> There are other branches (part of the skeleton) that also need branch
> weights adding.
>
> PR: #143035

This reverts commit db8d34d.
fhahn added a commit to fhahn/llvm-project that referenced this pull request Jun 14, 2025
fhahn added a commit that referenced this pull request Jun 14, 2025
Add test case with branch weights where the vector loop can
be removed. Exposed a crash with db8d34d
(#143035).
fhahn added a commit that referenced this pull request Jun 14, 2025
…NFC) (#143035)"

This reverts commit 0604dc1.

The recommitted version addresses post-commit comments and adjusts the
place the branch weights are added. It now runs before VPlans are optimized
for VF and UF, which may remove the vector loop region, causing a crash
trying to get the middle block after that. Test case added in
72f99b7.

Original message:
Manage branch weights for the BranchOnCond in the middle block in VPlan.
This requires updating VPInstruction to inherit from VPIRMetadata, which
in general makes sense as there are a number of opcodes that could take
metadata.

There are other branches (part of the skeleton) that also need branch
weights adding.

PR: #143035
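
The ordering issue described in the recommit message can be sketched with a toy model. `ToyPlan`, `optimizeForVFAndUF` and `addBranchWeights` below are hypothetical stand-ins; in particular, the sketch returns `false` where the real `getMiddleBlock()` asserts, purely so the behaviour can be observed without crashing.

```cpp
#include <cassert>

// Hypothetical stand-in for a plan whose vector loop region may be removed.
struct ToyPlan {
  bool HasVectorLoopRegion = true;
  bool MiddleTermHasWeights = false;
};

// Stand-in for a simplification step that may drop the vector loop region,
// e.g. when the vector loop is known to be removable.
void optimizeForVFAndUF(ToyPlan &Plan, bool LoopIsRemovable) {
  if (LoopIsRemovable)
    Plan.HasVectorLoopRegion = false;
}

// Attaches weights if the middle block can still be located via the region.
// Returns false (instead of asserting) once the region is gone.
bool addBranchWeights(ToyPlan &Plan) {
  if (!Plan.HasVectorLoopRegion)
    return false;
  Plan.MiddleTermHasWeights = true;
  return true;
}
```

Calling `addBranchWeights` before `optimizeForVFAndUF` succeeds even when the region is later removed, whereas the reverse order hits the missing-region case that caused the reported assertion failure.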