Skip to content

Commit 13af0d9

Browse files
authored
Fix switch phi inversion (rust-lang#467)
1 parent f6af51e commit 13af0d9

File tree

4 files changed

+116
-1
lines changed

4 files changed

+116
-1
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,21 @@ void calculateUnusedValuesInFunction(
895895
});
896896

897897
if (EnzymePrintUnnecessary) {
898+
llvm::errs() << " val use analysis of " << func.getName()
899+
<< ": mode=" << to_string(mode) << "\n";
900+
for (auto &BB : func)
901+
for (auto &I : BB) {
902+
bool ivn = is_value_needed_in_reverse<ValueType::Primal>(
903+
TR, gutils, &I, mode, PrimalSeen, oldUnreachable);
904+
bool isn = is_value_needed_in_reverse<ValueType::ShadowPtr>(
905+
TR, gutils, &I, mode, PrimalSeen, oldUnreachable);
906+
llvm::errs() << I << " ivn=" << (int)ivn << " isn: " << (int)isn;
907+
auto found = gutils->knownRecomputeHeuristic.find(&I);
908+
if (found != gutils->knownRecomputeHeuristic.end()) {
909+
llvm::errs() << " krc=" << (int)found->second;
910+
}
911+
llvm::errs() << "\n";
912+
}
898913
llvm::errs() << "unnecessaryValues of " << func.getName()
899914
<< ": mode=" << to_string(mode) << "\n";
900915
for (auto a : unnecessaryValues) {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4969,9 +4969,21 @@ fast:;
49694969
}
49704970
} else {
49714971
for (auto pair : *replacePHIs) {
4972-
Value *cas = si->findCaseDest(pair.first);
4972+
Value *cas = nullptr;
4973+
for (auto c : si->cases()) {
4974+
if (pair.first ==
4975+
*done[std::make_pair(block, c.getCaseSuccessor())].begin()) {
4976+
cas = c.getCaseValue();
4977+
break;
4978+
}
4979+
}
4980+
if (cas == nullptr) {
4981+
assert(pair.first ==
4982+
*done[std::make_pair(block, si->getDefaultDest())].begin());
4983+
}
49734984
Value *val = nullptr;
49744985
Value *phi = lookupM(si->getCondition(), BuilderM);
4986+
49754987
if (cas) {
49764988
val = BuilderM.CreateICmpEQ(cas, phi);
49774989
} else {

enzyme/Enzyme/GradientUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,15 @@ class GradientUtils : public CacheUtility {
525525
if (hasMetadata(CI, "enzyme_fromstack")) {
526526
allocationsWithGuaranteedFree[CI].insert(CI);
527527
}
528+
// TODO compute if an only load/store (non capture)
529+
// allocation by traversing its users. If so, mark
530+
// all of its load/stores, as now the loads can
531+
// potentially be rematerialized without a cache
532+
// of the allocation, but the operands of all stores.
533+
// This info needs to be provided to minCutCache
534+
// the derivative of store needs to redo the store,
535+
// isValueNeededInReverse needs to know to preserve the
536+
// store operands in this case, etc
528537
}
529538
}
530539
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
declare double @llvm.pow.f64(double, double)
4+
5+
declare dso_local double @__enzyme_autodiff(i8*, double, i64)
6+
7+
@.str = private unnamed_addr constant [10 x i8] c"result=%f\00", align 1
8+
9+
declare dso_local i32 @printf(i8*, ...)
10+
11+
define void @main() {
12+
entry:
13+
%call = tail call double @__enzyme_autodiff(i8* bitcast (double (double, i64)* @julia_euroad_1769 to i8*), double 0.5, i64 3)
14+
%p = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([10 x i8], [10 x i8]* @.str, i64 0, i64 0), double %call)
15+
ret void
16+
}
17+
18+
define double @julia_euroad_1769(double %arg, i64 %i5) {
19+
bb:
20+
switch i64 %i5, label %bb9 [
21+
i64 12, label %bb12
22+
i64 7, label %bb7
23+
]
24+
25+
bb7: ; preds = %bb
26+
%i7 = fmul double %arg, %arg
27+
br label %bb13
28+
29+
bb9: ; preds = %bb
30+
%ti5 = uitofp i64 %i5 to double
31+
%i9 = call double @llvm.pow.f64(double %arg, double %ti5)
32+
br label %bb13
33+
34+
bb12: ; preds = %bb
35+
br label %bb13
36+
37+
bb13: ; preds = %bb12, %bb9, %bb7
38+
%i14 = phi double [ %i7, %bb7 ], [ %i9, %bb9 ], [ %arg, %bb12 ]
39+
ret double %i14
40+
}
41+
42+
!llvm.module.flags = !{!0, !1}
43+
44+
!0 = !{i32 2, !"Dwarf Version", i32 4}
45+
!1 = !{i32 2, !"Debug Info Version", i32 3}
46+
47+
; CHECK: define internal { double } @diffejulia_euroad_1769(double %arg, i64 %i5, double %differeturn)
48+
; CHECK-NEXT: bb:
49+
; CHECK-NEXT: %0 = icmp eq i64 7, %i5
50+
; CHECK-NEXT: %1 = icmp eq i64 12, %i5
51+
; CHECK-NEXT: %2 = or i1 %1, %0
52+
; CHECK-NEXT: %3 = select {{(fast )?}}i1 %1, double %differeturn, double 0.000000e+00
53+
; CHECK-NEXT: %4 = select {{(fast )?}}i1 %2, double 0.000000e+00, double %differeturn
54+
; CHECK-NEXT: %5 = select {{(fast )?}}i1 %0, double %differeturn, double 0.000000e+00
55+
; CHECK-NEXT: switch i64 %i5, label %invertbb9 [
56+
; CHECK-NEXT: i64 12, label %invertbb
57+
; CHECK-NEXT: i64 7, label %invertbb7
58+
; CHECK-NEXT: ]
59+
60+
; CHECK: invertbb: ; preds = %bb, %invertbb9, %invertbb7
61+
; CHECK-NEXT: %"arg'de.0" = phi double [ %13, %invertbb9 ], [ %8, %invertbb7 ], [ %3, %bb ]
62+
; CHECK-NEXT: %6 = insertvalue { double } undef, double %"arg'de.0", 0
63+
; CHECK-NEXT: ret { double } %6
64+
65+
; CHECK: invertbb7: ; preds = %bb
66+
; CHECK-NEXT: %m0diffearg = fmul fast double %5, %arg
67+
; CHECK-NEXT: %7 = fadd fast double %3, %m0diffearg
68+
; CHECK-NEXT: %8 = fadd fast double %7, %m0diffearg
69+
; CHECK-NEXT: br label %invertbb
70+
71+
; CHECK: invertbb9: ; preds = %bb
72+
; CHECK-NEXT: %ti5_unwrap = uitofp i64 %i5 to double
73+
; CHECK-NEXT: %9 = fsub fast double %ti5_unwrap, 1.000000e+00
74+
; CHECK-NEXT: %10 = call fast double @llvm.pow.f64(double %arg, double %9)
75+
; CHECK-NEXT: %11 = fmul fast double %4, %10
76+
; CHECK-NEXT: %12 = fmul fast double %11, %ti5_unwrap
77+
; CHECK-NEXT: %13 = fadd fast double %3, %12
78+
; CHECK-NEXT: br label %invertbb
79+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)