Skip to content

Commit 6d3b2c4

Browse files
committed
call conv fix
1 parent d7f4967 commit 6d3b2c4

File tree

4 files changed

+234
-7
lines changed

4 files changed

+234
-7
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,6 +3390,12 @@ class AdjointGenerator
33903390
if (called &&
33913391
(called->getName() == "asin" || called->getName() == "asinf" ||
33923392
called->getName() == "asinl")) {
3393+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3394+
if (!gutils->knownRecomputeHeuristic[orig]) {
3395+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3396+
getIndex(orig, CacheType::Self));
3397+
}
3398+
}
33933399
eraseIfUnused(*orig);
33943400
if (Mode == DerivativeMode::ReverseModePrimal ||
33953401
gutils->isConstantInstruction(orig))
@@ -3418,6 +3424,12 @@ class AdjointGenerator
34183424
(called->getName() == "atan" || called->getName() == "atanf" ||
34193425
called->getName() == "atanl" ||
34203426
called->getName() == "__fd_atan_1")) {
3427+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3428+
if (!gutils->knownRecomputeHeuristic[orig]) {
3429+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3430+
getIndex(orig, CacheType::Self));
3431+
}
3432+
}
34213433
eraseIfUnused(*orig);
34223434
if (Mode == DerivativeMode::ReverseModePrimal ||
34233435
gutils->isConstantInstruction(orig))
@@ -3436,6 +3448,12 @@ class AdjointGenerator
34363448

34373449
if (called &&
34383450
(called->getName() == "tanhf" || called->getName() == "tanh")) {
3451+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3452+
if (!gutils->knownRecomputeHeuristic[orig]) {
3453+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3454+
getIndex(orig, CacheType::Self));
3455+
}
3456+
}
34393457
eraseIfUnused(*orig);
34403458
if (Mode == DerivativeMode::ReverseModePrimal ||
34413459
gutils->isConstantInstruction(orig))
@@ -3459,6 +3477,12 @@ class AdjointGenerator
34593477
}
34603478

34613479
if (called->getName() == "coshf" || called->getName() == "cosh") {
3480+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3481+
if (!gutils->knownRecomputeHeuristic[orig]) {
3482+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3483+
getIndex(orig, CacheType::Self));
3484+
}
3485+
}
34623486
eraseIfUnused(*orig);
34633487
if (Mode == DerivativeMode::ReverseModePrimal ||
34643488
gutils->isConstantInstruction(orig))
@@ -3480,6 +3504,12 @@ class AdjointGenerator
34803504
return;
34813505
}
34823506
if (called->getName() == "sinhf" || called->getName() == "sinh") {
3507+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3508+
if (!gutils->knownRecomputeHeuristic[orig]) {
3509+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3510+
getIndex(orig, CacheType::Self));
3511+
}
3512+
}
34833513
eraseIfUnused(*orig);
34843514
if (Mode == DerivativeMode::ReverseModePrimal ||
34853515
gutils->isConstantInstruction(orig))
@@ -3503,6 +3533,12 @@ class AdjointGenerator
35033533

35043534
if (called) {
35053535
if (called->getName() == "erf") {
3536+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3537+
if (!gutils->knownRecomputeHeuristic[orig]) {
3538+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3539+
getIndex(orig, CacheType::Self));
3540+
}
3541+
}
35063542
eraseIfUnused(*orig);
35073543
if (Mode == DerivativeMode::ReverseModePrimal ||
35083544
gutils->isConstantInstruction(orig))
@@ -3529,6 +3565,12 @@ class AdjointGenerator
35293565
return;
35303566
}
35313567
if (called->getName() == "erfi") {
3568+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3569+
if (!gutils->knownRecomputeHeuristic[orig]) {
3570+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3571+
getIndex(orig, CacheType::Self));
3572+
}
3573+
}
35323574
eraseIfUnused(*orig);
35333575
if (Mode == DerivativeMode::ReverseModePrimal ||
35343576
gutils->isConstantInstruction(orig))
@@ -3555,6 +3597,12 @@ class AdjointGenerator
35553597
return;
35563598
}
35573599
if (called->getName() == "erfc") {
3600+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3601+
if (!gutils->knownRecomputeHeuristic[orig]) {
3602+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3603+
getIndex(orig, CacheType::Self));
3604+
}
3605+
}
35583606
eraseIfUnused(*orig);
35593607
if (Mode == DerivativeMode::ReverseModePrimal ||
35603608
gutils->isConstantInstruction(orig))
@@ -3583,6 +3631,12 @@ class AdjointGenerator
35833631

35843632
if (called->getName() == "j0" || called->getName() == "y0" ||
35853633
called->getName() == "j0f" || called->getName() == "y0f") {
3634+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3635+
if (!gutils->knownRecomputeHeuristic[orig]) {
3636+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3637+
getIndex(orig, CacheType::Self));
3638+
}
3639+
}
35863640
eraseIfUnused(*orig);
35873641
if (Mode == DerivativeMode::ReverseModePrimal ||
35883642
gutils->isConstantInstruction(orig))
@@ -3609,6 +3663,12 @@ class AdjointGenerator
36093663

36103664
if (called->getName() == "j1" || called->getName() == "y1" ||
36113665
called->getName() == "j1f" || called->getName() == "y1f") {
3666+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3667+
if (!gutils->knownRecomputeHeuristic[orig]) {
3668+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3669+
getIndex(orig, CacheType::Self));
3670+
}
3671+
}
36123672
eraseIfUnused(*orig);
36133673
if (Mode == DerivativeMode::ReverseModePrimal ||
36143674
gutils->isConstantInstruction(orig))
@@ -3648,6 +3708,12 @@ class AdjointGenerator
36483708

36493709
if (called->getName() == "jn" || called->getName() == "yn" ||
36503710
called->getName() == "jnf" || called->getName() == "ynf") {
3711+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3712+
if (!gutils->knownRecomputeHeuristic[orig]) {
3713+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3714+
getIndex(orig, CacheType::Self));
3715+
}
3716+
}
36513717
eraseIfUnused(*orig);
36523718
if (Mode == DerivativeMode::ReverseModePrimal ||
36533719
gutils->isConstantInstruction(orig))
@@ -3717,6 +3783,12 @@ class AdjointGenerator
37173783
}
37183784
}
37193785
if (called->getName() == "__fd_sincos_1") {
3786+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3787+
if (!gutils->knownRecomputeHeuristic[orig]) {
3788+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3789+
getIndex(orig, CacheType::Self));
3790+
}
3791+
}
37203792
if (Mode == DerivativeMode::ReverseModePrimal ||
37213793
gutils->isConstantInstruction(orig)) {
37223794
eraseIfUnused(*orig);
@@ -3753,6 +3825,12 @@ class AdjointGenerator
37533825
}
37543826
if (called->getName() == "cabs" || called->getName() == "cabsf" ||
37553827
called->getName() == "cabsl") {
3828+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3829+
if (!gutils->knownRecomputeHeuristic[orig]) {
3830+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3831+
getIndex(orig, CacheType::Self));
3832+
}
3833+
}
37563834
if (Mode == DerivativeMode::ReverseModePrimal ||
37573835
gutils->isConstantInstruction(orig)) {
37583836
eraseIfUnused(*orig);
@@ -3789,6 +3867,12 @@ class AdjointGenerator
37893867
}
37903868
if (called->getName() == "ldexp" || called->getName() == "ldexpf" ||
37913869
called->getName() == "ldexpl") {
3870+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3871+
if (!gutils->knownRecomputeHeuristic[orig]) {
3872+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3873+
getIndex(orig, CacheType::Self));
3874+
}
3875+
}
37923876
if (Mode == DerivativeMode::ReverseModePrimal ||
37933877
gutils->isConstantInstruction(orig)) {
37943878
eraseIfUnused(*orig);
@@ -3815,6 +3899,12 @@ class AdjointGenerator
38153899
n == "lgamma_r" || n == "lgammaf_r" || n == "lgammal_r" ||
38163900
n == "__lgamma_r_finite" || n == "__lgammaf_r_finite" ||
38173901
n == "__lgammal_r_finite") {
3902+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3903+
if (!gutils->knownRecomputeHeuristic[orig]) {
3904+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3905+
getIndex(orig, CacheType::Self));
3906+
}
3907+
}
38183908
if (Mode == DerivativeMode::ReverseModePrimal ||
38193909
gutils->isConstantInstruction(orig)) {
38203910
return;
@@ -4049,8 +4139,20 @@ class AdjointGenerator
40494139
// gutils->isConstantValue(orig) << " subretused=" << subretused << " ivn:"
40504140
// << is_value_needed_in_reverse<Primal>(TR, gutils, &call, /*topLevel*/Mode
40514141
// == DerivativeMode::Both) << "\n";
4142+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
4143+
if (!gutils->knownRecomputeHeuristic[orig]) {
4144+
subretused = true;
4145+
}
4146+
}
40524147

40534148
if (gutils->isConstantInstruction(orig) && gutils->isConstantValue(orig)) {
4149+
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
4150+
if (!gutils->knownRecomputeHeuristic[orig]) {
4151+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
4152+
getIndex(orig, CacheType::Self));
4153+
return;
4154+
}
4155+
}
40544156
// If we need this value and it is illegal to recompute it (it writes or
40554157
// may load uncacheable data)
40564158
// Store and reload it

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ void calculateUnusedValuesInFunction(
818818
}
819819
return UseReq::Recur;
820820
});
821-
#if 0
821+
#if 1
822822
llvm::errs() << "unnecessaryValues of " << func.getName() << ":\n";
823823
for (auto a : unnecessaryValues) {
824824
llvm::errs() << *a << "\n";

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4341,8 +4341,8 @@ void GradientUtils::computeMinCache(
43414341
TR, this, &I,
43424342
/*topLevel*/ mode == DerivativeMode::ReverseModeCombined,
43434343
OneLevelSeen, guaranteedUnreachable);
4344-
// llvm::errs() << " not legal recompute: " << I << " oneneed: " <<
4345-
// (int)oneneed << "\n";
4344+
llvm::errs() << " not legal recompute: " << I << " oneneed: " <<
4345+
(int)oneneed << "\n";
43464346
if (oneneed)
43474347
knownRecomputeHeuristic[&I] = false;
43484348
else
@@ -4391,7 +4391,7 @@ void GradientUtils::computeMinCache(
43914391
TR, this, V,
43924392
/*topLevel*/ mode == DerivativeMode::ReverseModeCombined,
43934393
OneLevelSeen, guaranteedUnreachable)) {
4394-
// llvm::errs() << " Required: " << *V << "\n";
4394+
llvm::errs() << " Required: " << *V << "\n";
43954395
Required.insert(V);
43964396
} else {
43974397
for (auto V2 : V->users()) {
@@ -4422,11 +4422,11 @@ void GradientUtils::computeMinCache(
44224422
}
44234423

44244424
for (auto V : Intermediates) {
4425-
// llvm::errs() << " int: " << *V << " minreq: " << (int)MinReq.count(V)
4426-
// << "\n";
4425+
llvm::errs() << " int: " << *V << " minreq: " << (int)MinReq.count(V)
4426+
<< "\n";
44274427
knownRecomputeHeuristic[V] = !MinReq.count(V);
44284428
if (!NeedGraph.count(V)) {
4429-
// llvm::errs() << " ++ unnecessary\n";
4429+
llvm::errs() << " ++ unnecessary\n";
44304430
unnecessaryIntermediates.insert(cast<Instruction>(V));
44314431
}
44324432
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -S | FileCheck %s
2+
; ModuleID = 'inp.ll'
3+
4+
declare dso_local void @_Z17__enzyme_autodiffPvPdS0_i(i8*, double*, double*, i64*) local_unnamed_addr #4
5+
; Test driver: hands @_Z10reduce_maxPdi (bitcast to i8*) plus %m, %m2 and %n to
; the @_Z17__enzyme_autodiffPvPdS0_i declaration.
; NOTE(review): %m2 is presumably the shadow (derivative) buffer paired with %m - confirm.
define dso_local void @outer(double* %m, double* %m2, i64* %n) local_unnamed_addr #2 {
6+
entry:
7+
call void @_Z17__enzyme_autodiffPvPdS0_i(i8* bitcast (double (double*, i64*)* @_Z10reduce_maxPdi to i8*), double* nonnull %m, double* nonnull %m2, i64* %n)
8+
ret void
9+
}
10+
; Function handed to __enzyme_autodiff by @outer: forwards both arguments to @pb,
; then overwrites v[0] with 0 and returns @pb's result.
; Function Attrs (group #0, defined further down): readnone speculatable
; NOTE(review): #0 claims readnone, yet this body stores through %v - confirm the
; attribute is intended for this test.
11+
define dso_local double @_Z10reduce_maxPdi(double* %vec, i64* %v) #0 {
12+
entry:
13+
%res = call double @pb(double* %vec, i64* %v)
14+
store i64 0, i64* %v, align 8 ; clobbers v[0] after @pb has already read it; presumably why the expected output caches %icall rather than reloading - confirm
15+
ret double %res
16+
}
17+
18+
; Loop under test: for i = 0..3, scales __x[i] in place by tpfop(v[i]).
; Always returns 0.0; the observable effect is the in-place update of %__x.
define double @pb(double* %__x, i64* %v) {
19+
entry:
20+
br label %for.body
21+
22+
for.body: ; preds = %for.body, %entry
23+
%tiv = phi i64 [ %inc, %for.body ], [ 0, %entry ]
24+
%ig = getelementptr inbounds i64, i64* %v, i64 %tiv
25+
%iload = load i64, i64* %ig
26+
%icall = call double @tpfop(i64 %iload) ; v[i] reinterpreted as a double
27+
28+
%dg = getelementptr inbounds double, double* %__x, i64 %tiv
29+
%dload = load double, double* %dg
30+
%mul = fmul double %dload, %icall
31+
store double %mul, double* %dg ; __x[i] *= icall
32+
33+
%inc = add nsw i64 %tiv, 1
34+
%cmp = icmp slt i64 %inc, 4 ; fixed trip count of 4
35+
br i1 %cmp, label %for.body, label %for.end
36+
37+
for.end: ; preds = %for.body
38+
ret double 0.000000e+00
39+
}
40+
41+
; Loads and returns ptr[off].
; Not called by any function visible in this test.
define double @usesize(double* %ptr, i64 %off) {
42+
entry:
43+
%p2 = getelementptr inbounds double, double* %ptr, i64 %off
44+
%ld = load double, double* %p2, align 8
45+
ret double %ld
46+
}
47+
48+
; Reinterprets the bits of the i64 argument as a double (bitcast, no numeric
; conversion) and returns it. Attribute group #0 marks it readnone speculatable.
define double @tpfop(i64 %ptr) #0 {
49+
entry:
50+
%d = bitcast i64 %ptr to double
51+
ret double %d
52+
}
53+
54+
!4 = !{!5, i64 1, !"omnipotent char"}
55+
!5 = !{!"Simple C++ TBAA"}
56+
!6 = !{!7, !7, i64 0, i64 8}
57+
!7 = !{!4, i64 8, !"long"}
58+
59+
attributes #0 = { readnone speculatable }
60+
61+
62+
; CHECK: define internal double* @augmented_pb(double* %__x, double* %"__x'", i64* %v)
63+
; CHECK-NEXT: entry:
64+
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(32) dereferenceable_or_null(32) i8* @malloc(i64 32)
65+
; CHECK-NEXT: %icall_malloccache = bitcast i8* %malloccall to double*
66+
; CHECK-NEXT: br label %for.body
67+
68+
; CHECK: for.body: ; preds = %for.body, %entry
69+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ]
70+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
71+
; CHECK-NEXT: %ig = getelementptr inbounds i64, i64* %v, i64 %iv
72+
; CHECK-NEXT: %iload = load i64, i64* %ig
73+
; CHECK-NEXT: %icall = call double @tpfop(i64 %iload)
74+
; CHECK-NEXT: %dg = getelementptr inbounds double, double* %__x, i64 %iv
75+
; CHECK-NEXT: %dload = load double, double* %dg
76+
; CHECK-NEXT: %mul = fmul double %dload, %icall
77+
; CHECK-NEXT: store double %mul, double* %dg
78+
; CHECK-NEXT: %0 = getelementptr inbounds double, double* %icall_malloccache, i64 %iv
79+
; CHECK-NEXT: store double %icall, double* %0, align 8, !invariant.group !0
80+
; CHECK-NEXT: %cmp = icmp ne i64 %iv.next, 4
81+
; CHECK-NEXT: br i1 %cmp, label %for.body, label %for.end
82+
83+
; CHECK: for.end: ; preds = %for.body
84+
; CHECK-NEXT: ret double* %icall_malloccache
85+
; CHECK-NEXT: }
86+
87+
; CHECK: define internal void @diffepb(double* %__x, double* %"__x'", i64* %v, double %differeturn, double* %tapeArg)
88+
; CHECK-NEXT: entry:
89+
; CHECK-NEXT: br label %for.body
90+
91+
; CHECK: for.body: ; preds = %for.body, %entry
92+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.body ], [ 0, %entry ]
93+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
94+
; CHECK-NEXT: %0 = getelementptr inbounds double, double* %tapeArg, i64 %iv
95+
; CHECK-NEXT: %icall = load double, double* %0, align 8, !invariant.group !1
96+
; CHECK-NEXT: %"dg'ipg" = getelementptr inbounds double, double* %"__x'", i64 %iv
97+
; CHECK-NEXT: %cmp = icmp ne i64 %iv.next, 4
98+
; CHECK-NEXT: br i1 %cmp, label %for.body, label %invertfor.body
99+
100+
; CHECK: invertentry: ; preds = %invertfor.body
101+
; CHECK-NEXT: %1 = bitcast double* %tapeArg to i8*
102+
; CHECK-NEXT: tail call void @free(i8* nonnull %1)
103+
; CHECK-NEXT: ret void
104+
105+
; CHECK: invertfor.body: ; preds = %for.body, %incinvertfor.body
106+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %11, %incinvertfor.body ], [ 3, %for.body ]
107+
; CHECK-NEXT: %"dg'ipg_unwrap" = getelementptr inbounds double, double* %"__x'", i64 %"iv'ac.0"
108+
; CHECK-NEXT: %2 = load double, double* %"dg'ipg_unwrap"
109+
; CHECK-NEXT: store double 0.000000e+00, double* %"dg'ipg_unwrap"
110+
; CHECK-NEXT: %3 = fadd fast double 0.000000e+00, %2
111+
; CHECK-NEXT: %4 = getelementptr inbounds double, double* %tapeArg, i64 %"iv'ac.0"
112+
; CHECK-NEXT: %5 = load double, double* %4, align 8, !invariant.group !2
113+
; CHECK-NEXT: %m0diffedload = fmul fast double %3, %5
114+
; CHECK-NEXT: %6 = fadd fast double 0.000000e+00, %m0diffedload
115+
; CHECK-NEXT: %7 = load double, double* %"dg'ipg_unwrap"
116+
; CHECK-NEXT: %8 = fadd fast double %7, %6
117+
; CHECK-NEXT: store double %8, double* %"dg'ipg_unwrap"
118+
; CHECK-NEXT: %9 = icmp eq i64 %"iv'ac.0", 0
119+
; CHECK-NEXT: %10 = xor i1 %9, true
120+
; CHECK-NEXT: br i1 %9, label %invertentry, label %incinvertfor.body
121+
122+
; CHECK: incinvertfor.body: ; preds = %invertfor.body
123+
; CHECK-NEXT: %11 = add nsw i64 %"iv'ac.0", -1
124+
; CHECK-NEXT: br label %invertfor.body
125+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)