Skip to content

Commit 83628bf

Browse files
authored
Correct caching behavior of intrinsics (rust-lang#482)
1 parent 8faf3b4 commit 83628bf

File tree

4 files changed

+162
-10
lines changed

4 files changed

+162
-10
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,6 +2805,18 @@ class AdjointGenerator
28052805
orig_ops[i] = II.getOperand(i);
28062806
}
28072807
handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops);
2808+
if (gutils->knownRecomputeHeuristic.find(&II) !=
2809+
gutils->knownRecomputeHeuristic.end()) {
2810+
if (!gutils->knownRecomputeHeuristic[&II]) {
2811+
CallInst *const newCall =
2812+
cast<CallInst>(gutils->getNewFromOriginal(&II));
2813+
IRBuilder<> BuilderZ(newCall);
2814+
BuilderZ.setFastMathFlags(getFast());
2815+
2816+
gutils->cacheForReverse(BuilderZ, newCall,
2817+
getIndex(&II, CacheType::Self));
2818+
}
2819+
}
28082820
eraseIfUnused(II);
28092821
}
28102822

@@ -7859,6 +7871,14 @@ class AdjointGenerator
78597871
if (isMemFreeLibMFunction(funcName, &ID)) {
78607872
if (Mode == DerivativeMode::ReverseModePrimal ||
78617873
gutils->isConstantInstruction(orig)) {
7874+
7875+
if (gutils->knownRecomputeHeuristic.find(orig) !=
7876+
gutils->knownRecomputeHeuristic.end()) {
7877+
if (!gutils->knownRecomputeHeuristic[orig]) {
7878+
gutils->cacheForReverse(BuilderZ, newCall,
7879+
getIndex(orig, CacheType::Self));
7880+
}
7881+
}
78627882
eraseIfUnused(*orig);
78637883
return;
78647884
}
@@ -7869,6 +7889,13 @@ class AdjointGenerator
78697889
orig_ops[i] = orig->getOperand(i);
78707890
}
78717891
handleAdjointForIntrinsic(ID, *orig, orig_ops);
7892+
if (gutils->knownRecomputeHeuristic.find(orig) !=
7893+
gutils->knownRecomputeHeuristic.end()) {
7894+
if (!gutils->knownRecomputeHeuristic[orig]) {
7895+
gutils->cacheForReverse(BuilderZ, newCall,
7896+
getIndex(orig, CacheType::Self));
7897+
}
7898+
}
78727899
eraseIfUnused(*orig);
78737900
return;
78747901
}

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,32 @@ static inline bool is_use_directly_needed_in_reverse(
9797
return false;
9898
}
9999

100+
Intrinsic::ID ID = Intrinsic::not_intrinsic;
100101
if (auto II = dyn_cast<IntrinsicInst>(user)) {
101-
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
102-
II->getIntrinsicID() == Intrinsic::lifetime_end ||
103-
II->getIntrinsicID() == Intrinsic::stacksave ||
104-
II->getIntrinsicID() == Intrinsic::stackrestore) {
102+
ID = II->getIntrinsicID();
103+
} else if (auto CI = dyn_cast<CallInst>(user)) {
104+
if (auto called = getFunctionFromCall(const_cast<CallInst *>(CI))) {
105+
StringRef funcName;
106+
if (called->hasFnAttribute("enzyme_math"))
107+
funcName = called->getFnAttribute("enzyme_math").getValueAsString();
108+
else
109+
funcName = called->getName();
110+
isMemFreeLibMFunction(funcName, &ID);
111+
}
112+
}
113+
114+
if (ID != Intrinsic::not_intrinsic) {
115+
if (ID == Intrinsic::lifetime_start || ID == Intrinsic::lifetime_end ||
116+
ID == Intrinsic::stacksave || ID == Intrinsic::stackrestore) {
105117
return false;
106118
}
107-
if (II->getIntrinsicID() == Intrinsic::fma ||
108-
II->getIntrinsicID() == Intrinsic::fmuladd) {
119+
if (ID == Intrinsic::fma || ID == Intrinsic::fmuladd) {
109120
bool needed = false;
110-
if (II->getArgOperand(0) == val &&
111-
!gutils->isConstantValue(II->getArgOperand(1)))
121+
if (user->getOperand(0) == val &&
122+
!gutils->isConstantValue(user->getOperand(1)))
112123
needed = true;
113-
if (II->getArgOperand(1) == val &&
114-
!gutils->isConstantValue(II->getArgOperand(0)))
124+
if (user->getOperand(1) == val &&
125+
!gutils->isConstantValue(user->getOperand(0)))
115126
needed = true;
116127
return needed;
117128
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -instsimplify -gvn -adce -S | FileCheck %s
2+
source_filename = "/app/example.c"
3+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
; Function Attrs: mustprogress nofree noinline nosync nounwind readonly uwtable willreturn
7+
define dso_local double @foo(double* nocapture readonly %x) {
8+
entry:
9+
%0 = load double, double* %x, align 8
10+
%arrayidx1 = getelementptr inbounds double, double* %x, i64 2
11+
%1 = load double, double* %arrayidx1, align 8
12+
%f2 = tail call double @llvm.fma.f64(double %0, double 2.000000e+00, double %1)
13+
%mul = fmul double %f2, %f2
14+
ret double %mul
15+
}
16+
17+
declare double @llvm.fma.f64(double, double, double) readnone
18+
19+
; Function Attrs: mustprogress nofree nosync nounwind uwtable willreturn
20+
define dso_local double @square(double* nocapture %x) {
21+
entry:
22+
%call = tail call double @foo(double* %x)
23+
store double 0.000000e+00, double* %x, align 8
24+
ret double %call
25+
}
26+
27+
define dso_local double @dsquare(double* %x, double* %dx) {
28+
entry:
29+
%call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double*)* @square to i8*), double* %x, double* %dx)
30+
ret double %call
31+
}
32+
33+
declare dso_local double @__enzyme_autodiff(i8*, ...)
34+
35+
; CHECK: define internal double @augmented_foo(double* nocapture readonly %x, double* nocapture %"x'")
36+
; CHECK-NEXT: entry:
37+
; CHECK-NEXT: %0 = load double, double* %x, align 8
38+
; CHECK-NEXT: %arrayidx1 = getelementptr inbounds double, double* %x, i64 2
39+
; CHECK-NEXT: %1 = load double, double* %arrayidx1, align 8
40+
; CHECK-NEXT: %f2 = tail call double @llvm.fma.f64(double %0, double 2.000000e+00, double %1)
41+
; CHECK-NEXT: ret double %f2
42+
; CHECK-NEXT: }
43+
44+
; CHECK: define internal void @diffefoo(double* nocapture readonly %x, double* nocapture %"x'", double %differeturn, double %f2)
45+
; CHECK-NEXT: entry:
46+
; CHECK-NEXT: %"arrayidx1'ipg" = getelementptr inbounds double, double* %"x'", i64 2
47+
; CHECK-NEXT: %m0diffef2 = fmul fast double %differeturn, %f2
48+
; CHECK-NEXT: %0 = fadd fast double %m0diffef2, %m0diffef2
49+
; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00
50+
; CHECK-NEXT: %2 = load double, double* %"arrayidx1'ipg", align 8
51+
; CHECK-NEXT: %3 = fadd fast double %2, %0
52+
; CHECK-NEXT: store double %3, double* %"arrayidx1'ipg", align 8
53+
; CHECK-NEXT: %4 = load double, double* %"x'", align 8
54+
; CHECK-NEXT: %5 = fadd fast double %4, %1
55+
; CHECK-NEXT: store double %5, double* %"x'", align 8
56+
; CHECK-NEXT: ret void
57+
; CHECK-NEXT: }
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -instsimplify -gvn -adce -S | FileCheck %s
2+
source_filename = "/app/example.c"
3+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
; Function Attrs: mustprogress nofree noinline nosync nounwind readonly uwtable willreturn
7+
define dso_local double @foo(double* nocapture readonly %x) {
8+
entry:
9+
%0 = load double, double* %x, align 8
10+
%arrayidx1 = getelementptr inbounds double, double* %x, i64 2
11+
%1 = load double, double* %arrayidx1, align 8
12+
%f2 = tail call double @fma(double %0, double 2.000000e+00, double %1)
13+
%mul = fmul double %f2, %f2
14+
ret double %mul
15+
}
16+
17+
declare double @fma(double, double, double) readnone
18+
19+
; Function Attrs: mustprogress nofree nosync nounwind uwtable willreturn
20+
define dso_local double @square(double* nocapture %x) {
21+
entry:
22+
%call = tail call double @foo(double* %x)
23+
store double 0.000000e+00, double* %x, align 8
24+
ret double %call
25+
}
26+
27+
define dso_local double @dsquare(double* %x, double* %dx) {
28+
entry:
29+
%call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double*)* @square to i8*), double* %x, double* %dx)
30+
ret double %call
31+
}
32+
33+
declare dso_local double @__enzyme_autodiff(i8*, ...)
34+
35+
; CHECK: define internal double @augmented_foo(double* nocapture readonly %x, double* nocapture %"x'")
36+
; CHECK-NEXT: entry:
37+
; CHECK-NEXT: %0 = load double, double* %x, align 8
38+
; CHECK-NEXT: %arrayidx1 = getelementptr inbounds double, double* %x, i64 2
39+
; CHECK-NEXT: %1 = load double, double* %arrayidx1, align 8
40+
; CHECK-NEXT: %f2 = tail call double @fma(double %0, double 2.000000e+00, double %1)
41+
; CHECK-NEXT: ret double %f2
42+
; CHECK-NEXT: }
43+
44+
; CHECK: define internal void @diffefoo(double* nocapture readonly %x, double* nocapture %"x'", double %differeturn, double %f2)
45+
; CHECK-NEXT: entry:
46+
; CHECK-NEXT: %"arrayidx1'ipg" = getelementptr inbounds double, double* %"x'", i64 2
47+
; CHECK-NEXT: %m0diffef2 = fmul fast double %differeturn, %f2
48+
; CHECK-NEXT: %0 = fadd fast double %m0diffef2, %m0diffef2
49+
; CHECK-NEXT: %1 = fmul fast double %0, 2.000000e+00
50+
; CHECK-NEXT: %2 = load double, double* %"arrayidx1'ipg", align 8
51+
; CHECK-NEXT: %3 = fadd fast double %2, %0
52+
; CHECK-NEXT: store double %3, double* %"arrayidx1'ipg", align 8
53+
; CHECK-NEXT: %4 = load double, double* %"x'", align 8
54+
; CHECK-NEXT: %5 = fadd fast double %4, %1
55+
; CHECK-NEXT: store double %5, double* %"x'", align 8
56+
; CHECK-NEXT: ret void
57+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)