30
30
#include " llvm/IR/Metadata.h"
31
31
#include " llvm/IR/PassManager.h"
32
32
#include " llvm/IR/PatternMatch.h"
33
+ #include " llvm/IR/ProfDataUtils.h"
33
34
#include " llvm/IR/Type.h"
34
35
#include " llvm/IR/Use.h"
35
36
#include " llvm/IR/Value.h"
47
48
#include " llvm/Transforms/Utils/SSAUpdater.h"
48
49
#include < algorithm>
49
50
#include < cassert>
51
+ #include < optional>
50
52
#include < utility>
51
53
52
54
using namespace llvm ;
@@ -85,7 +87,46 @@ using PhiMap = MapVector<PHINode *, BBValueVector>;
85
87
using BB2BBVecMap = MapVector<BasicBlock *, BBVector>;
86
88
87
89
using BBPhiMap = DenseMap<BasicBlock *, PhiMap>;
88
- using BBPredicates = DenseMap<BasicBlock *, Value *>;
90
+
91
+ using MaybeCondBranchWeights = std::optional<class CondBranchWeights >;
92
+
93
+ class CondBranchWeights {
94
+ uint32_t TrueWeight;
95
+ uint32_t FalseWeight;
96
+
97
+ public:
98
+ CondBranchWeights (unsigned T, unsigned F) : TrueWeight(T), FalseWeight(F) {}
99
+
100
+ static MaybeCondBranchWeights tryParse (const BranchInst &Br) {
101
+ assert (Br.isConditional ());
102
+
103
+ SmallVector<uint32_t , 2 > Weights;
104
+ if (!extractBranchWeights (Br, Weights))
105
+ return std::nullopt;
106
+
107
+ if (Weights.size () != 2 )
108
+ return std::nullopt;
109
+
110
+ return CondBranchWeights{Weights[0 ], Weights[1 ]};
111
+ }
112
+
113
+ static void setMetadata (BranchInst &Br,
114
+ MaybeCondBranchWeights const &Weights) {
115
+ assert (Br.isConditional ());
116
+ if (!Weights)
117
+ return ;
118
+ uint32_t Arr[] = {Weights->TrueWeight , Weights->FalseWeight };
119
+ setBranchWeights (Br, Arr, false );
120
+ }
121
+
122
+ CondBranchWeights invert () const {
123
+ return CondBranchWeights{FalseWeight, TrueWeight};
124
+ }
125
+ };
126
+
127
+ using ValueWeightPair = std::pair<Value *, MaybeCondBranchWeights>;
128
+
129
+ using BBPredicates = DenseMap<BasicBlock *, ValueWeightPair>;
89
130
using PredMap = DenseMap<BasicBlock *, BBPredicates>;
90
131
using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>;
91
132
@@ -271,7 +312,7 @@ class StructurizeCFG {
271
312
272
313
void analyzeLoops (RegionNode *N);
273
314
274
- Value * buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
315
+ ValueWeightPair buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
275
316
276
317
void gatherPredicates (RegionNode *N);
277
318
@@ -449,16 +490,22 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
449
490
}
450
491
451
492
// / Build the condition for one edge
452
- Value * StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
453
- bool Invert) {
493
+ ValueWeightPair StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
494
+ bool Invert) {
454
495
Value *Cond = Invert ? BoolFalse : BoolTrue;
496
+ MaybeCondBranchWeights Weights = std::nullopt;
497
+
455
498
if (Term->isConditional ()) {
456
499
Cond = Term->getCondition ();
500
+ Weights = CondBranchWeights::tryParse (*Term);
457
501
458
- if (Idx != (unsigned )Invert)
502
+ if (Idx != (unsigned )Invert) {
459
503
Cond = invertCondition (Cond);
504
+ if (Weights)
505
+ Weights = Weights->invert ();
506
+ }
460
507
}
461
- return Cond;
508
+ return { Cond, Weights} ;
462
509
}
463
510
464
511
// / Analyze the predecessors of each block and build up predicates
@@ -490,8 +537,8 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
490
537
if (Visited.count (Other) && !Loops.count (Other) &&
491
538
!Pred.count (Other) && !Pred.count (P)) {
492
539
493
- Pred[Other] = BoolFalse;
494
- Pred[P] = BoolTrue;
540
+ Pred[Other] = { BoolFalse, std::nullopt} ;
541
+ Pred[P] = { BoolTrue, std::nullopt} ;
495
542
continue ;
496
543
}
497
544
}
@@ -512,9 +559,9 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
512
559
513
560
BasicBlock *Entry = R->getEntry ();
514
561
if (Visited.count (Entry))
515
- Pred[Entry] = BoolTrue;
562
+ Pred[Entry] = { BoolTrue, std::nullopt} ;
516
563
else
517
- LPred[Entry] = BoolFalse;
564
+ LPred[Entry] = { BoolFalse, std::nullopt} ;
518
565
}
519
566
}
520
567
}
@@ -578,12 +625,14 @@ void StructurizeCFG::insertConditions(bool Loops) {
578
625
Dominator.addBlock (Parent);
579
626
580
627
Value *ParentValue = nullptr ;
581
- for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) {
628
+ MaybeCondBranchWeights ParentWeights = std::nullopt;
629
+ for (std::pair<BasicBlock *, ValueWeightPair> BBAndPred : Preds) {
582
630
BasicBlock *BB = BBAndPred.first ;
583
- Value *Pred = BBAndPred.second ;
631
+ Value *Pred = BBAndPred.second . first ;
584
632
585
633
if (BB == Parent) {
586
634
ParentValue = Pred;
635
+ ParentWeights = BBAndPred.second .second ;
587
636
break ;
588
637
}
589
638
PhiInserter.AddAvailableValue (BB, Pred);
@@ -592,6 +641,7 @@ void StructurizeCFG::insertConditions(bool Loops) {
592
641
593
642
if (ParentValue) {
594
643
Term->setCondition (ParentValue);
644
+ CondBranchWeights::setMetadata (*Term, ParentWeights);
595
645
} else {
596
646
if (!Dominator.resultIsRememberedBlock ())
597
647
PhiInserter.AddAvailableValue (Dominator.result (), Default);
@@ -607,7 +657,7 @@ void StructurizeCFG::simplifyConditions() {
607
657
for (auto &I : concat<PredMap::value_type>(Predicates, LoopPreds)) {
608
658
auto &Preds = I.second ;
609
659
for (auto &J : Preds) {
610
- auto &Cond = J.second ;
660
+ auto &Cond = J.second . first ;
611
661
Instruction *Inverted;
612
662
if (match (Cond, m_Not (m_OneUse (m_Instruction (Inverted)))) &&
613
663
!Cond->use_empty ()) {
@@ -904,9 +954,10 @@ void StructurizeCFG::setPrevNode(BasicBlock *BB) {
904
954
// / Does BB dominate all the predicates of Node?
905
955
bool StructurizeCFG::dominatesPredicates (BasicBlock *BB, RegionNode *Node) {
906
956
BBPredicates &Preds = Predicates[Node->getEntry ()];
907
- return llvm::all_of (Preds, [&](std::pair<BasicBlock *, Value *> Pred) {
908
- return DT->dominates (BB, Pred.first );
909
- });
957
+ return llvm::all_of (Preds,
958
+ [&](std::pair<BasicBlock *, ValueWeightPair> Pred) {
959
+ return DT->dominates (BB, Pred.first );
960
+ });
910
961
}
911
962
912
963
// / Can we predict that this node will always be called?
@@ -918,9 +969,9 @@ bool StructurizeCFG::isPredictableTrue(RegionNode *Node) {
918
969
if (!PrevNode)
919
970
return true ;
920
971
921
- for (std::pair<BasicBlock*, Value* > Pred : Preds) {
972
+ for (std::pair<BasicBlock *, ValueWeightPair > Pred : Preds) {
922
973
BasicBlock *BB = Pred.first ;
923
- Value *V = Pred.second ;
974
+ Value *V = Pred.second . first ;
924
975
925
976
if (V != BoolTrue)
926
977
return false ;
0 commit comments