Skip to content

Commit f054947

Browse files
authored
[SimplifyCFG] Prevent merging cbranch to cbranch if the branch probability from the first to second is too low. (#69375)
AMDGPU target has faced the situation which can be illustrated with the following testcase: define void @dont_merge_cbranches(i32 %V) { %divergent_cond = icmp ne i32 %V, 0 %uniform_cond = call i1 @uniform_result(i1 %divergent_cond) br i1 %uniform_cond, label %bb2, label %exit, !prof !0 bb2: br i1 %divergent_cond, label %bb3, label %exit bb3: call void @bar( ) br label %exit exit: ret void } !0 = !{!"branch_weights", i32 1, i32 100000} SimplifyCFG merges branches on %uniform_cond and %divergent_cond which is undesirable because the first branch to bb2 is taken extremely rare and the second branch is expensive. The merged branch becomes as expensive as the second. This patch prevents such merging if the branch to the second branch is unlikely to happen.
1 parent dde85f8 commit f054947

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4347,6 +4347,20 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
43474347
if (PBI->getSuccessor(PBIOp) == BB)
43484348
return false;
43494349

4350+
// If predecessor's branch probability to BB is too low don't merge branches.
4351+
SmallVector<uint32_t, 2> PredWeights;
4352+
if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
4353+
extractBranchWeights(*PBI, PredWeights) &&
4354+
(PredWeights[0] + PredWeights[1]) != 0) {
4355+
4356+
BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
4357+
PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);
4358+
4359+
BranchProbability Likely = TTI.getPredictableBranchThreshold();
4360+
if (CommonDestProb >= Likely)
4361+
return false;
4362+
}
4363+
43504364
// Do not perform this transformation if it would require
43514365
// insertion of a large number of select instructions. For targets
43524366
// without predication/cmovs, this is a big pessimization.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=simplifycfg -S | FileCheck %s
3+
4+
declare void @bar()
5+
declare i1 @uniform_result(i1 %c)
6+
7+
define void @dont_merge_cbranches1(i32 %V) {
8+
; CHECK-LABEL: @dont_merge_cbranches1(
9+
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
10+
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
11+
; CHECK-NEXT: br i1 [[UNIFORM_COND]], label [[BB2:%.*]], label [[EXIT:%.*]], !prof [[PROF0:![0-9]+]]
12+
; CHECK: bb2:
13+
; CHECK-NEXT: br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
14+
; CHECK: bb3:
15+
; CHECK-NEXT: call void @bar()
16+
; CHECK-NEXT: br label [[EXIT]]
17+
; CHECK: exit:
18+
; CHECK-NEXT: ret void
19+
;
20+
%divergent_cond = icmp ne i32 %V, 0
21+
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
22+
br i1 %uniform_cond, label %bb2, label %exit, !prof !0
23+
bb2:
24+
br i1 %divergent_cond, label %bb3, label %exit
25+
bb3:
26+
call void @bar( )
27+
br label %exit
28+
exit:
29+
ret void
30+
}
31+
32+
define void @dont_merge_cbranches2(i32 %V) {
33+
; CHECK-LABEL: @dont_merge_cbranches2(
34+
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
35+
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
36+
; CHECK-NEXT: br i1 [[UNIFORM_COND]], label [[EXIT:%.*]], label [[BB2:%.*]], !prof [[PROF1:![0-9]+]]
37+
; CHECK: bb2:
38+
; CHECK-NEXT: br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
39+
; CHECK: bb3:
40+
; CHECK-NEXT: call void @bar()
41+
; CHECK-NEXT: br label [[EXIT]]
42+
; CHECK: exit:
43+
; CHECK-NEXT: ret void
44+
;
45+
%divergent_cond = icmp ne i32 %V, 0
46+
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
47+
br i1 %uniform_cond, label %exit, label %bb2, !prof !1
48+
bb2:
49+
br i1 %divergent_cond, label %bb3, label %exit
50+
bb3:
51+
call void @bar( )
52+
br label %exit
53+
exit:
54+
ret void
55+
}
56+
57+
define void @merge_cbranches(i32 %V) {
58+
; CHECK-LABEL: @merge_cbranches(
59+
; CHECK-NEXT: [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
60+
; CHECK-NEXT: [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
61+
; CHECK-NEXT: [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
62+
; CHECK-NEXT: [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
63+
; CHECK-NEXT: br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF2:![0-9]+]]
64+
; CHECK: bb3:
65+
; CHECK-NEXT: call void @bar()
66+
; CHECK-NEXT: br label [[EXIT]]
67+
; CHECK: exit:
68+
; CHECK-NEXT: ret void
69+
;
70+
%divergent_cond = icmp ne i32 %V, 0
71+
%uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
72+
br i1 %uniform_cond, label %exit, label %bb2, !prof !2
73+
bb2:
74+
br i1 %divergent_cond, label %bb3, label %exit
75+
bb3:
76+
call void @bar( )
77+
br label %exit
78+
exit:
79+
ret void
80+
}
81+
82+
!0 = !{!"branch_weights", i32 1, i32 1000}
83+
!1 = !{!"branch_weights", i32 1000, i32 1}
84+
!2 = !{!"branch_weights", i32 3, i32 2}

0 commit comments

Comments
 (0)