@@ -594,11 +594,13 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
594
594
ExitCaseIndices.push_back (Case.getCaseIndex ());
595
595
}
596
596
BasicBlock *DefaultExitBB = nullptr ;
597
+ SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight =
598
+ SwitchInstProfUpdateWrapper::getSuccessorWeight (SI, 0 );
597
599
if (!L.contains (SI.getDefaultDest ()) &&
598
600
areLoopExitPHIsLoopInvariant (L, *ParentBB, *SI.getDefaultDest ()) &&
599
- !isa<UnreachableInst>(SI.getDefaultDest ()->getTerminator ()))
601
+ !isa<UnreachableInst>(SI.getDefaultDest ()->getTerminator ())) {
600
602
DefaultExitBB = SI.getDefaultDest ();
601
- else if (ExitCaseIndices.empty ())
603
+ } else if (ExitCaseIndices.empty ())
602
604
return false ;
603
605
604
606
LLVM_DEBUG (dbgs () << " unswitching trivial switch...\n " );
@@ -622,8 +624,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
622
624
623
625
// Store the exit cases into a separate data structure and remove them from
624
626
// the switch.
625
- SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4 > ExitCases;
627
+ SmallVector<std::tuple<ConstantInt *, BasicBlock *,
628
+ SwitchInstProfUpdateWrapper::CaseWeightOpt>,
629
+ 4 > ExitCases;
626
630
ExitCases.reserve (ExitCaseIndices.size ());
631
+ SwitchInstProfUpdateWrapper SIW (SI);
627
632
// We walk the case indices backwards so that we remove the last case first
628
633
// and don't disrupt the earlier indices.
629
634
for (unsigned Index : reverse (ExitCaseIndices)) {
@@ -633,9 +638,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
633
638
if (!ExitL || ExitL->contains (OuterL))
634
639
OuterL = ExitL;
635
640
// Save the value of this case.
636
- ExitCases.push_back ({CaseI->getCaseValue (), CaseI->getCaseSuccessor ()});
641
+ auto W = SIW.getSuccessorWeight (CaseI->getSuccessorIndex ());
642
+ ExitCases.emplace_back (CaseI->getCaseValue (), CaseI->getCaseSuccessor (), W);
637
643
// Delete the unswitched cases.
638
- SI .removeCase (CaseI);
644
+ SIW .removeCase (CaseI);
639
645
}
640
646
641
647
if (SE) {
@@ -673,6 +679,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
673
679
674
680
// Now add the unswitched switch.
675
681
auto *NewSI = SwitchInst::Create (LoopCond, NewPH, ExitCases.size (), OldPH);
682
+ SwitchInstProfUpdateWrapper NewSIW (*NewSI);
676
683
677
684
// Rewrite the IR for the unswitched basic blocks. This requires two steps.
678
685
// First, we split any exit blocks with remaining in-loop predecessors. Then
@@ -700,9 +707,9 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
700
707
}
701
708
// Note that we must use a reference in the for loop so that we update the
702
709
// container.
703
- for (auto &CasePair : reverse (ExitCases)) {
710
+ for (auto &ExitCase : reverse (ExitCases)) {
704
711
// Grab a reference to the exit block in the pair so that we can update it.
705
- BasicBlock *ExitBB = CasePair. second ;
712
+ BasicBlock *ExitBB = std::get< 1 >(ExitCase) ;
706
713
707
714
// If this case is the last edge into the exit block, we can simply reuse it
708
715
// as it will no longer be a loop exit. No mapping necessary.
@@ -724,27 +731,39 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
724
731
/* FullUnswitch*/ true );
725
732
}
726
733
// Update the case pair to point to the split block.
727
- CasePair. second = SplitExitBB;
734
+ std::get< 1 >(ExitCase) = SplitExitBB;
728
735
}
729
736
730
737
// Now add the unswitched cases. We do this in reverse order as we built them
731
738
// in reverse order.
732
- for (auto CasePair : reverse (ExitCases)) {
733
- ConstantInt *CaseVal = CasePair. first ;
734
- BasicBlock *UnswitchedBB = CasePair. second ;
739
+ for (auto &ExitCase : reverse (ExitCases)) {
740
+ ConstantInt *CaseVal = std::get< 0 >(ExitCase) ;
741
+ BasicBlock *UnswitchedBB = std::get< 1 >(ExitCase) ;
735
742
736
- NewSI-> addCase (CaseVal, UnswitchedBB);
743
+ NewSIW. addCase (CaseVal, UnswitchedBB, std::get< 2 >(ExitCase) );
737
744
}
738
745
739
746
// If the default was unswitched, re-point it and add explicit cases for
740
747
// entering the loop.
741
748
if (DefaultExitBB) {
742
- NewSI->setDefaultDest (DefaultExitBB);
749
+ NewSIW->setDefaultDest (DefaultExitBB);
750
+ NewSIW.setSuccessorWeight (0 , DefaultCaseWeight);
743
751
744
752
// We removed all the exit cases, so we just copy the cases to the
745
753
// unswitched switch.
746
- for (auto Case : SI.cases ())
747
- NewSI->addCase (Case.getCaseValue (), NewPH);
754
+ for (const auto &Case : SI.cases ())
755
+ NewSIW.addCase (Case.getCaseValue (), NewPH,
756
+ SIW.getSuccessorWeight (Case.getSuccessorIndex ()));
757
+ } else if (DefaultCaseWeight) {
758
+ // We have to set branch weight of the default case.
759
+ uint64_t SW = *DefaultCaseWeight;
760
+ for (const auto &Case : SI.cases ()) {
761
+ auto W = SIW.getSuccessorWeight (Case.getSuccessorIndex ());
762
+ assert (W &&
763
+ " case weight must be defined as default case weight is defined" );
764
+ SW += *W;
765
+ }
766
+ NewSIW.setSuccessorWeight (0 , SW);
748
767
}
749
768
750
769
// If we ended up with a common successor for every path through the switch
@@ -769,7 +788,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
769
788
/* KeepOneInputPHIs*/ true );
770
789
}
771
790
// Now nuke the switch and replace it with a direct branch.
772
- SI .eraseFromParent ();
791
+ SIW .eraseFromParent ();
773
792
BranchInst::Create (CommonSuccBB, BB);
774
793
} else if (DefaultExitBB) {
775
794
assert (SI.getNumCases () > 0 &&
@@ -779,8 +798,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
779
798
// being simple and keeping the number of edges from this switch to
780
799
// successors the same, and avoiding any PHI update complexity.
781
800
auto LastCaseI = std::prev (SI.case_end ());
801
+
782
802
SI.setDefaultDest (LastCaseI->getCaseSuccessor ());
783
- SI.removeCase (LastCaseI);
803
+ SIW.setSuccessorWeight (
804
+ 0 , SIW.getSuccessorWeight (LastCaseI->getSuccessorIndex ()));
805
+ SIW.removeCase (LastCaseI);
784
806
}
785
807
786
808
// Walk the unswitched exit blocks and the unswitched split blocks and update
0 commit comments