30
30
#include " llvm/IR/Metadata.h"
31
31
#include " llvm/IR/PassManager.h"
32
32
#include " llvm/IR/PatternMatch.h"
33
+ #include " llvm/IR/ProfDataUtils.h"
33
34
#include " llvm/IR/Type.h"
34
35
#include " llvm/IR/Use.h"
35
36
#include " llvm/IR/Value.h"
@@ -85,7 +86,43 @@ using PhiMap = MapVector<PHINode *, BBValueVector>;
85
86
using BB2BBVecMap = MapVector<BasicBlock *, BBVector>;
86
87
87
88
using BBPhiMap = DenseMap<BasicBlock *, PhiMap>;
88
- using BBPredicates = DenseMap<BasicBlock *, Value *>;
89
+
90
+ using MaybeCondBranchWeights = std::optional<class CondBranchWeights >;
91
+
92
+ class CondBranchWeights {
93
+ uint32_t TrueWeight;
94
+ uint32_t FalseWeight;
95
+
96
+ CondBranchWeights (uint32_t T, uint32_t F) : TrueWeight(T), FalseWeight(F) {}
97
+
98
+ public:
99
+ static MaybeCondBranchWeights tryParse (const BranchInst &Br) {
100
+ assert (Br.isConditional ());
101
+
102
+ uint64_t T, F;
103
+ if (!extractBranchWeights (Br, T, F))
104
+ return std::nullopt;
105
+
106
+ return CondBranchWeights (T, F);
107
+ }
108
+
109
+ static void setMetadata (BranchInst &Br,
110
+ const MaybeCondBranchWeights &Weights) {
111
+ assert (Br.isConditional ());
112
+ if (!Weights)
113
+ return ;
114
+ uint32_t Arr[] = {Weights->TrueWeight , Weights->FalseWeight };
115
+ setBranchWeights (Br, Arr, false );
116
+ }
117
+
118
+ CondBranchWeights invert () const {
119
+ return CondBranchWeights{FalseWeight, TrueWeight};
120
+ }
121
+ };
122
+
123
+ using ValueWeightPair = std::pair<Value *, MaybeCondBranchWeights>;
124
+
125
+ using BBPredicates = DenseMap<BasicBlock *, ValueWeightPair>;
89
126
using PredMap = DenseMap<BasicBlock *, BBPredicates>;
90
127
using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>;
91
128
@@ -271,7 +308,7 @@ class StructurizeCFG {
271
308
272
309
void analyzeLoops (RegionNode *N);
273
310
274
- Value * buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
311
+ ValueWeightPair buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
275
312
276
313
void gatherPredicates (RegionNode *N);
277
314
@@ -449,16 +486,22 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
449
486
}
450
487
451
488
// / Build the condition for one edge
452
- Value * StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
453
- bool Invert) {
489
+ ValueWeightPair StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
490
+ bool Invert) {
454
491
Value *Cond = Invert ? BoolFalse : BoolTrue;
492
+ MaybeCondBranchWeights Weights;
493
+
455
494
if (Term->isConditional ()) {
456
495
Cond = Term->getCondition ();
496
+ Weights = CondBranchWeights::tryParse (*Term);
457
497
458
- if (Idx != (unsigned )Invert)
498
+ if (Idx != (unsigned )Invert) {
459
499
Cond = invertCondition (Cond);
500
+ if (Weights)
501
+ Weights = Weights->invert ();
502
+ }
460
503
}
461
- return Cond;
504
+ return { Cond, Weights} ;
462
505
}
463
506
464
507
// / Analyze the predecessors of each block and build up predicates
@@ -490,8 +533,8 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
490
533
if (Visited.count (Other) && !Loops.count (Other) &&
491
534
!Pred.count (Other) && !Pred.count (P)) {
492
535
493
- Pred[Other] = BoolFalse;
494
- Pred[P] = BoolTrue;
536
+ Pred[Other] = { BoolFalse, std::nullopt} ;
537
+ Pred[P] = { BoolTrue, std::nullopt} ;
495
538
continue ;
496
539
}
497
540
}
@@ -512,9 +555,9 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
512
555
513
556
BasicBlock *Entry = R->getEntry ();
514
557
if (Visited.count (Entry))
515
- Pred[Entry] = BoolTrue;
558
+ Pred[Entry] = { BoolTrue, std::nullopt} ;
516
559
else
517
- LPred[Entry] = BoolFalse;
560
+ LPred[Entry] = { BoolFalse, std::nullopt} ;
518
561
}
519
562
}
520
563
}
@@ -578,12 +621,14 @@ void StructurizeCFG::insertConditions(bool Loops) {
578
621
Dominator.addBlock (Parent);
579
622
580
623
Value *ParentValue = nullptr ;
581
- for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) {
624
+ MaybeCondBranchWeights ParentWeights = std::nullopt;
625
+ for (std::pair<BasicBlock *, ValueWeightPair> BBAndPred : Preds) {
582
626
BasicBlock *BB = BBAndPred.first ;
583
- Value * Pred = BBAndPred.second ;
627
+ auto [ Pred, Weight] = BBAndPred.second ;
584
628
585
629
if (BB == Parent) {
586
630
ParentValue = Pred;
631
+ ParentWeights = Weight;
587
632
break ;
588
633
}
589
634
PhiInserter.AddAvailableValue (BB, Pred);
@@ -592,6 +637,7 @@ void StructurizeCFG::insertConditions(bool Loops) {
592
637
593
638
if (ParentValue) {
594
639
Term->setCondition (ParentValue);
640
+ CondBranchWeights::setMetadata (*Term, ParentWeights);
595
641
} else {
596
642
if (!Dominator.resultIsRememberedBlock ())
597
643
PhiInserter.AddAvailableValue (Dominator.result (), Default);
@@ -607,7 +653,7 @@ void StructurizeCFG::simplifyConditions() {
607
653
for (auto &I : concat<PredMap::value_type>(Predicates, LoopPreds)) {
608
654
auto &Preds = I.second ;
609
655
for (auto &J : Preds) {
610
- auto & Cond = J.second ;
656
+ Value * Cond = J.second . first ;
611
657
Instruction *Inverted;
612
658
if (match (Cond, m_Not (m_OneUse (m_Instruction (Inverted)))) &&
613
659
!Cond->use_empty ()) {
@@ -904,9 +950,10 @@ void StructurizeCFG::setPrevNode(BasicBlock *BB) {
904
950
// / Does BB dominate all the predicates of Node?
905
951
bool StructurizeCFG::dominatesPredicates (BasicBlock *BB, RegionNode *Node) {
906
952
BBPredicates &Preds = Predicates[Node->getEntry ()];
907
- return llvm::all_of (Preds, [&](std::pair<BasicBlock *, Value *> Pred) {
908
- return DT->dominates (BB, Pred.first );
909
- });
953
+ return llvm::all_of (Preds,
954
+ [&](std::pair<BasicBlock *, ValueWeightPair> Pred) {
955
+ return DT->dominates (BB, Pred.first );
956
+ });
910
957
}
911
958
912
959
// / Can we predict that this node will always be called?
@@ -918,9 +965,9 @@ bool StructurizeCFG::isPredictableTrue(RegionNode *Node) {
918
965
if (!PrevNode)
919
966
return true ;
920
967
921
- for (std::pair<BasicBlock*, Value* > Pred : Preds) {
968
+ for (std::pair<BasicBlock *, ValueWeightPair > Pred : Preds) {
922
969
BasicBlock *BB = Pred.first ;
923
- Value *V = Pred.second ;
970
+ Value *V = Pred.second . first ;
924
971
925
972
if (V != BoolTrue)
926
973
return false ;
0 commit comments