Skip to content

LoopRotationUtils: Special case zero-branch weight cases #66681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 52 additions & 23 deletions llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,33 +295,62 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
// We cannot generally deduce how often we had a zero-trip count loop so we
// have to make a guess for how to distribute x among the new x0 and x1.

uint32_t ExitWeight0 = 0; // aka x0
if (HasConditionalPreHeader) {
// Here we cannot know how many 0-trip count loops we have, so we guess:
if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
// If the loop count is bigger than the exit count then we set
// probabilities as if 0-trip count nearly never happens.
ExitWeight0 = ZeroTripCountWeights[0];
// Scale up counts if necessary so we can match `ZeroTripCountWeights` for
// the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
// ... but don't overflow.
uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
(OrigLoopExitWeight & HighBit) != 0)
break;
OrigLoopBackedgeWeight <<= 1;
OrigLoopExitWeight <<= 1;
uint32_t ExitWeight0; // aka x0
uint32_t ExitWeight1; // aka x1
uint32_t EnterWeight; // aka y0
uint32_t LoopBackWeight; // aka y1
if (OrigLoopExitWeight > 0 && OrigLoopBackedgeWeight > 0) {
ExitWeight0 = 0;
if (HasConditionalPreHeader) {
// Here we cannot know how many 0-trip count loops we have, so we guess:
if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
// If the loop count is bigger than the exit count then we set
// probabilities as if 0-trip count nearly never happens.
ExitWeight0 = ZeroTripCountWeights[0];
// Scale up counts if necessary so we can match `ZeroTripCountWeights`
// for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
// ... but don't overflow.
uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
(OrigLoopExitWeight & HighBit) != 0)
break;
OrigLoopBackedgeWeight <<= 1;
OrigLoopExitWeight <<= 1;
}
} else {
// If there's a higher exit-count than backedge-count then we set
// probabilities as if there are only 0-trip and 1-trip cases.
ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
}
}
ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
EnterWeight = ExitWeight1;
LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
} else if (OrigLoopExitWeight == 0) {
if (OrigLoopBackedgeWeight == 0) {
// degenerate case... keep everything zero...
ExitWeight0 = 0;
ExitWeight1 = 0;
EnterWeight = 0;
LoopBackWeight = 0;
} else {
// If there's a higher exit-count than backedge-count then we set
// probabilities as if there are only 0-trip and 1-trip cases.
ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
// Special case "LoopExitWeight == 0" weights which behaves like an
// endless where we don't want loop-enttry (y0) to be the same as
// loop-exit (x1).
ExitWeight0 = 0;
ExitWeight1 = 0;
EnterWeight = 1;
LoopBackWeight = OrigLoopBackedgeWeight;
}
} else {
// loop is never entered.
assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
ExitWeight0 = 1;
ExitWeight1 = 1;
EnterWeight = 0;
LoopBackWeight = 0;
}
uint32_t ExitWeight1 = OrigLoopExitWeight - ExitWeight0; // aka x1
uint32_t EnterWeight = ExitWeight1; // aka y0
uint32_t LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; // aka y1

MDBuilder MDB(LoopBI.getContext());
MDNode *LoopWeightMD =
Expand Down
97 changes: 97 additions & 0 deletions llvm/test/Transforms/LoopRotate/update-branch-weights.ll
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
; BFI_AFTER: - inner_loop_exit: {{.*}} count = 1000
; BFI_AFTER: - outer_loop_exit: {{.*}} count = 1

; IR-LABEL: define void @func0
; IR: inner_loop_body:
; IR: br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof [[PROF_FUNC0_0:![0-9]+]]
; IR: inner_loop_exit:
Expand Down Expand Up @@ -74,6 +75,7 @@ outer_loop_exit:
; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
; BFI_AFTER: - loop_exit: {{.*}} count = 1024

; IR-LABEL: define void @func1
; IR: entry:
; IR: br i1 %cmp1, label %loop_body.lr.ph, label %loop_exit, !prof [[PROF_FUNC1_0:![0-9]+]]

Expand Down Expand Up @@ -114,6 +116,7 @@ loop_exit:
; - loop_header.loop_exit_crit_edge: {{.*}} count = 32
; - loop_exit: {{.*}} count = 1024

; IR-LABEL: define void @func2
; IR: entry:
; IR: br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC2_0:![0-9]+]]

Expand Down Expand Up @@ -141,16 +144,110 @@ loop_exit:
ret void
}

; BFI_BEFORE-LABEL: block-frequency-info: func3_zero_branch_weight
; BFI_BEFORE: - entry: {{.*}} count = 1024
; BFI_BEFORE: - loop_header: {{.*}} count = 2199023255296
; BFI_BEFORE: - loop_body: {{.*}} count = 2199023254272
; BFI_BEFORE: - loop_exit: {{.*}} count = 1024

; BFI_AFTER-LABEL: block-frequency-info: func3_zero_branch_weight
; BFI_AFTER: - entry: {{.*}} count = 1024
; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1024
; BFI_AFTER: - loop_body: {{.*}} count = 2199023255296
; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
; BFI_AFTER: - loop_exit: {{.*}} count = 1024

; IR-LABEL: define void @func3_zero_branch_weight
; IR: entry:
; IR: br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC3_0:![0-9]+]]

; IR: loop_body:
; IR: br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC3_0]]

define void @func3_zero_branch_weight(i32 %n) !prof !3 {
entry:
br label %loop_header

loop_header:
%i = phi i32 [0, %entry], [%i_inc, %loop_body]
%cmp = icmp slt i32 %i, %n
br i1 %cmp, label %loop_exit, label %loop_body, !prof !6

loop_body:
store volatile i32 %i, ptr @g, align 4
%i_inc = add i32 %i, 1
br label %loop_header

loop_exit:
ret void
}

; IR-LABEL: define void @func4_zero_branch_weight
; IR: entry:
; IR: br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC4_0:![0-9]+]]

; IR: loop_body:
; IR: br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC4_0]]

define void @func4_zero_branch_weight(i32 %n) !prof !3 {
entry:
br label %loop_header

loop_header:
%i = phi i32 [0, %entry], [%i_inc, %loop_body]
%cmp = icmp slt i32 %i, %n
br i1 %cmp, label %loop_exit, label %loop_body, !prof !7

loop_body:
store volatile i32 %i, ptr @g, align 4
%i_inc = add i32 %i, 1
br label %loop_header

loop_exit:
ret void
}

; IR-LABEL: define void @func5_zero_branch_weight
; IR: entry:
; IR: br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC5_0:![0-9]+]]

; IR: loop_body:
; IR: br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC5_0]]

define void @func5_zero_branch_weight(i32 %n) !prof !3 {
entry:
br label %loop_header

loop_header:
%i = phi i32 [0, %entry], [%i_inc, %loop_body]
%cmp = icmp slt i32 %i, %n
br i1 %cmp, label %loop_exit, label %loop_body, !prof !8

loop_body:
store volatile i32 %i, ptr @g, align 4
%i_inc = add i32 %i, 1
br label %loop_header

loop_exit:
ret void
}

!0 = !{!"function_entry_count", i64 1}
!1 = !{!"branch_weights", i32 1000, i32 1}
!2 = !{!"branch_weights", i32 3000, i32 1000}
!3 = !{!"function_entry_count", i64 1024}
!4 = !{!"branch_weights", i32 40, i32 2}
!5 = !{!"branch_weights", i32 10240, i32 320}
!6 = !{!"branch_weights", i32 0, i32 1}
!7 = !{!"branch_weights", i32 1, i32 0}
!8 = !{!"branch_weights", i32 0, i32 0}

; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
; IR: [[PROF_FUNC4_0]] = !{!"branch_weights", i32 1, i32 0}
; IR: [[PROF_FUNC5_0]] = !{!"branch_weights", i32 0, i32 0}