Skip to content

Commit e903c5a

Browse files
committed
[AMDGPU][StructurizeCFG] Maintain branch MD_prof metadata
1 parent bd9a3b0 commit e903c5a

File tree

2 files changed

+75
-20
lines changed

2 files changed

+75
-20
lines changed

llvm/lib/Transforms/Scalar/StructurizeCFG.cpp

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/IR/Metadata.h"
3131
#include "llvm/IR/PassManager.h"
3232
#include "llvm/IR/PatternMatch.h"
33+
#include "llvm/IR/ProfDataUtils.h"
3334
#include "llvm/IR/Type.h"
3435
#include "llvm/IR/Use.h"
3536
#include "llvm/IR/Value.h"
@@ -47,6 +48,7 @@
4748
#include "llvm/Transforms/Utils/SSAUpdater.h"
4849
#include <algorithm>
4950
#include <cassert>
51+
#include <optional>
5052
#include <utility>
5153

5254
using namespace llvm;
@@ -85,7 +87,46 @@ using PhiMap = MapVector<PHINode *, BBValueVector>;
8587
using BB2BBVecMap = MapVector<BasicBlock *, BBVector>;
8688

8789
using BBPhiMap = DenseMap<BasicBlock *, PhiMap>;
88-
using BBPredicates = DenseMap<BasicBlock *, Value *>;
90+
91+
using MaybeCondBranchWeights = std::optional<class CondBranchWeights>;
92+
93+
class CondBranchWeights {
94+
uint32_t TrueWeight;
95+
uint32_t FalseWeight;
96+
97+
public:
98+
CondBranchWeights(unsigned T, unsigned F) : TrueWeight(T), FalseWeight(F) {}
99+
100+
static MaybeCondBranchWeights tryParse(const BranchInst &Br) {
101+
assert(Br.isConditional());
102+
103+
SmallVector<uint32_t, 2> Weights;
104+
if (!extractBranchWeights(Br, Weights))
105+
return std::nullopt;
106+
107+
if (Weights.size() != 2)
108+
return std::nullopt;
109+
110+
return CondBranchWeights{Weights[0], Weights[1]};
111+
}
112+
113+
static void setMetadata(BranchInst &Br,
114+
MaybeCondBranchWeights const &Weights) {
115+
assert(Br.isConditional());
116+
if (!Weights)
117+
return;
118+
uint32_t Arr[] = {Weights->TrueWeight, Weights->FalseWeight};
119+
setBranchWeights(Br, Arr, false);
120+
}
121+
122+
CondBranchWeights invert() const {
123+
return CondBranchWeights{FalseWeight, TrueWeight};
124+
}
125+
};
126+
127+
using ValueWeightPair = std::pair<Value *, MaybeCondBranchWeights>;
128+
129+
using BBPredicates = DenseMap<BasicBlock *, ValueWeightPair>;
89130
using PredMap = DenseMap<BasicBlock *, BBPredicates>;
90131
using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>;
91132

@@ -271,7 +312,7 @@ class StructurizeCFG {
271312

272313
void analyzeLoops(RegionNode *N);
273314

274-
Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert);
315+
ValueWeightPair buildCondition(BranchInst *Term, unsigned Idx, bool Invert);
275316

276317
void gatherPredicates(RegionNode *N);
277318

@@ -449,16 +490,22 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
449490
}
450491

451492
/// Build the condition for one edge
452-
Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,
453-
bool Invert) {
493+
ValueWeightPair StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,
494+
bool Invert) {
454495
Value *Cond = Invert ? BoolFalse : BoolTrue;
496+
MaybeCondBranchWeights Weights = std::nullopt;
497+
455498
if (Term->isConditional()) {
456499
Cond = Term->getCondition();
500+
Weights = CondBranchWeights::tryParse(*Term);
457501

458-
if (Idx != (unsigned)Invert)
502+
if (Idx != (unsigned)Invert) {
459503
Cond = invertCondition(Cond);
504+
if (Weights)
505+
Weights = Weights->invert();
506+
}
460507
}
461-
return Cond;
508+
return {Cond, Weights};
462509
}
463510

464511
/// Analyze the predecessors of each block and build up predicates
@@ -490,8 +537,8 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
490537
if (Visited.count(Other) && !Loops.count(Other) &&
491538
!Pred.count(Other) && !Pred.count(P)) {
492539

493-
Pred[Other] = BoolFalse;
494-
Pred[P] = BoolTrue;
540+
Pred[Other] = {BoolFalse, std::nullopt};
541+
Pred[P] = {BoolTrue, std::nullopt};
495542
continue;
496543
}
497544
}
@@ -512,9 +559,9 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
512559

513560
BasicBlock *Entry = R->getEntry();
514561
if (Visited.count(Entry))
515-
Pred[Entry] = BoolTrue;
562+
Pred[Entry] = {BoolTrue, std::nullopt};
516563
else
517-
LPred[Entry] = BoolFalse;
564+
LPred[Entry] = {BoolFalse, std::nullopt};
518565
}
519566
}
520567
}
@@ -578,12 +625,14 @@ void StructurizeCFG::insertConditions(bool Loops) {
578625
Dominator.addBlock(Parent);
579626

580627
Value *ParentValue = nullptr;
581-
for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) {
628+
MaybeCondBranchWeights ParentWeights = std::nullopt;
629+
for (std::pair<BasicBlock *, ValueWeightPair> BBAndPred : Preds) {
582630
BasicBlock *BB = BBAndPred.first;
583-
Value *Pred = BBAndPred.second;
631+
Value *Pred = BBAndPred.second.first;
584632

585633
if (BB == Parent) {
586634
ParentValue = Pred;
635+
ParentWeights = BBAndPred.second.second;
587636
break;
588637
}
589638
PhiInserter.AddAvailableValue(BB, Pred);
@@ -592,6 +641,7 @@ void StructurizeCFG::insertConditions(bool Loops) {
592641

593642
if (ParentValue) {
594643
Term->setCondition(ParentValue);
644+
CondBranchWeights::setMetadata(*Term, ParentWeights);
595645
} else {
596646
if (!Dominator.resultIsRememberedBlock())
597647
PhiInserter.AddAvailableValue(Dominator.result(), Default);
@@ -607,7 +657,7 @@ void StructurizeCFG::simplifyConditions() {
607657
for (auto &I : concat<PredMap::value_type>(Predicates, LoopPreds)) {
608658
auto &Preds = I.second;
609659
for (auto &J : Preds) {
610-
auto &Cond = J.second;
660+
auto &Cond = J.second.first;
611661
Instruction *Inverted;
612662
if (match(Cond, m_Not(m_OneUse(m_Instruction(Inverted)))) &&
613663
!Cond->use_empty()) {
@@ -904,9 +954,10 @@ void StructurizeCFG::setPrevNode(BasicBlock *BB) {
904954
/// Does BB dominate all the predicates of Node?
905955
bool StructurizeCFG::dominatesPredicates(BasicBlock *BB, RegionNode *Node) {
906956
BBPredicates &Preds = Predicates[Node->getEntry()];
907-
return llvm::all_of(Preds, [&](std::pair<BasicBlock *, Value *> Pred) {
908-
return DT->dominates(BB, Pred.first);
909-
});
957+
return llvm::all_of(Preds,
958+
[&](std::pair<BasicBlock *, ValueWeightPair> Pred) {
959+
return DT->dominates(BB, Pred.first);
960+
});
910961
}
911962

912963
/// Can we predict that this node will always be called?
@@ -918,9 +969,9 @@ bool StructurizeCFG::isPredictableTrue(RegionNode *Node) {
918969
if (!PrevNode)
919970
return true;
920971

921-
for (std::pair<BasicBlock*, Value*> Pred : Preds) {
972+
for (std::pair<BasicBlock *, ValueWeightPair> Pred : Preds) {
922973
BasicBlock *BB = Pred.first;
923-
Value *V = Pred.second;
974+
Value *V = Pred.second.first;
924975

925976
if (V != BoolTrue)
926977
return false;

llvm/test/CodeGen/AMDGPU/structurizer-keep-perf-md.ll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ define amdgpu_ps i32 @if_else(i32 %0) {
55
; OPT-LABEL: define amdgpu_ps i32 @if_else(
66
; OPT-SAME: i32 [[TMP0:%.*]]) {
77
; OPT-NEXT: [[C:%.*]] = icmp ne i32 [[TMP0]], 0
8-
; OPT-NEXT: br i1 [[C]], label %[[FALSE:.*]], label %[[FLOW:.*]]
8+
; OPT-NEXT: br i1 [[C]], label %[[FALSE:.*]], label %[[FLOW:.*]], !prof [[PROF0:![0-9]+]]
99
; OPT: [[FLOW]]:
1010
; OPT-NEXT: [[TMP2:%.*]] = phi i32 [ 33, %[[FALSE]] ], [ undef, [[TMP1:%.*]] ]
1111
; OPT-NEXT: [[TMP3:%.*]] = phi i1 [ false, %[[FALSE]] ], [ true, [[TMP1]] ]
@@ -40,7 +40,7 @@ define amdgpu_ps void @loop_if_break(i32 %n) {
4040
; OPT: [[LOOP]]:
4141
; OPT-NEXT: [[I:%.*]] = phi i32 [ [[N]], %[[ENTRY]] ], [ [[TMP0:%.*]], %[[FLOW:.*]] ]
4242
; OPT-NEXT: [[C:%.*]] = icmp ugt i32 [[I]], 0
43-
; OPT-NEXT: br i1 [[C]], label %[[LOOP_BODY:.*]], label %[[FLOW]]
43+
; OPT-NEXT: br i1 [[C]], label %[[LOOP_BODY:.*]], label %[[FLOW]], !prof [[PROF1:![0-9]+]]
4444
; OPT: [[LOOP_BODY]]:
4545
; OPT-NEXT: [[I_NEXT:%.*]] = sub i32 [[I]], 1
4646
; OPT-NEXT: br label %[[FLOW]]
@@ -70,3 +70,7 @@ exit: ; preds = %loop
7070
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
7171

7272
!0 = !{!"branch_weights", i32 1000, i32 1}
73+
;.
74+
; OPT: [[PROF0]] = !{!"branch_weights", i32 1, i32 1000}
75+
; OPT: [[PROF1]] = !{!"branch_weights", i32 1000, i32 1}
76+
;.

0 commit comments

Comments
 (0)