Skip to content

Commit 91447ec

Browse files
authored
Fix malloc placeholder (rust-lang#343)
* fix malloc placeholder
1 parent a318065 commit 91447ec

File tree

3 files changed

+110
-5
lines changed

3 files changed

+110
-5
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7309,10 +7309,20 @@ class AdjointGenerator
73097309
} else if (Mode == DerivativeMode::ForwardMode) {
73107310
IRBuilder<> Builder2(&call);
73117311
getForwardBuilder(Builder2);
7312+
73127313
SmallVector<Value *, 2> args = {orig->getArgOperand(0)};
73137314
CallInst *CI = Builder2.CreateCall(orig->getFunctionType(),
73147315
orig->getCalledFunction(), args);
73157316
CI->setAttributes(orig->getAttributes());
7317+
7318+
auto found = gutils->invertedPointers.find(orig);
7319+
PHINode *placeholder = cast<PHINode>(&*found->second);
7320+
7321+
gutils->invertedPointers.erase(found);
7322+
gutils->replaceAWithB(placeholder, CI);
7323+
gutils->erase(placeholder);
7324+
gutils->invertedPointers.insert(
7325+
std::make_pair(orig, InvertedPointerVH(gutils, CI)));
73167326
return;
73177327
}
73187328
}
@@ -7322,8 +7332,10 @@ class AdjointGenerator
73227332
if (!pair.second)
73237333
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
73247334
bool primalNeededInReverse =
7325-
is_value_needed_in_reverse<ValueType::Primal>(TR, gutils, orig, Mode,
7326-
Seen, oldUnreachable);
7335+
Mode == DerivativeMode::ForwardMode
7336+
? false
7337+
: is_value_needed_in_reverse<ValueType::Primal>(
7338+
TR, gutils, orig, Mode, Seen, oldUnreachable);
73277339
bool hasPDFree = gutils->allocationsWithGuaranteedFree.count(orig);
73287340
if (!primalNeededInReverse && hasPDFree) {
73297341
if (Mode == DerivativeMode::ReverseModeGradient) {

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,9 +1331,13 @@ Function *PreProcessCache::preprocessForClone(Function *F,
13311331
auto PA = LoopSimplifyPass().run(*NewF, FAM);
13321332
FAM.invalidate(*NewF, PA);
13331333

1334-
// For subfunction calls upgrade stack allocations to mallocs
1335-
// to ensure availability in the reverse pass
1336-
UpgradeAllocasToMallocs(NewF, mode);
1334+
if (mode == DerivativeMode::ReverseModePrimal ||
1335+
mode == DerivativeMode::ReverseModeGradient ||
1336+
mode == DerivativeMode::ReverseModeCombined) {
1337+
// For subfunction calls upgrade stack allocations to mallocs
1338+
// to ensure availability in the reverse pass
1339+
UpgradeAllocasToMallocs(NewF, mode);
1340+
}
13371341

13381342
CanonicalizeLoops(NewF, FAM);
13391343

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; #include <stdio.h>
4+
5+
; double __enzyme_fwddiff(void*, ...);
6+
7+
; __attribute__((noinline))
8+
; void square_(const double* src, double* dest) {
9+
; *dest = *src * *src;
10+
; }
11+
12+
; double square(double x) {
13+
; double y;
14+
; square_(&x, &y);
15+
; return y;
16+
; }
17+
18+
; double dsquare(double x) {
19+
; return __enzyme_fwddiff((void*)square, x, 1.0);
20+
; }
21+
22+
23+
define dso_local void @square_(double* nocapture readonly %src, double* nocapture %dest) local_unnamed_addr #0 {
24+
entry:
25+
%0 = load double, double* %src, align 8
26+
%mul = fmul double %0, %0
27+
store double %mul, double* %dest, align 8
28+
ret void
29+
}
30+
31+
define dso_local double @square(double %x) #1 {
32+
entry:
33+
%x.addr = alloca double, align 8
34+
%y = alloca double, align 8
35+
store double %x, double* %x.addr, align 8
36+
%0 = bitcast double* %y to i8*
37+
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #4
38+
call void @square_(double* nonnull %x.addr, double* nonnull %y)
39+
%1 = load double, double* %y, align 8
40+
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #4
41+
ret double %1
42+
}
43+
44+
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #2
45+
46+
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #2
47+
48+
define dso_local double @dsquare(double %x) local_unnamed_addr #1 {
49+
entry:
50+
%call = tail call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @square to i8*), double %x, double 1.000000e+00) #4
51+
ret double %call
52+
}
53+
54+
declare dso_local double @__enzyme_fwddiff(i8*, ...) local_unnamed_addr #3
55+
56+
attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
57+
attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
58+
attributes #2 = { argmemonly nounwind }
59+
attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
60+
attributes #4 = { nounwind }
61+
62+
63+
; CHECK: define internal double @fwddiffesquare(double %x, double %"x'")
64+
; CHECK-NEXT: entry:
65+
; CHECK-NEXT: %"x.addr'ipa" = alloca double, align 8
66+
; CHECK-NEXT: store double 0.000000e+00, double* %"x.addr'ipa", align 8
67+
; CHECK-NEXT: %x.addr = alloca double, align 8
68+
; CHECK-NEXT: %"y'ipa" = alloca double, align 8
69+
; CHECK-NEXT: store double 0.000000e+00, double* %"y'ipa", align 8
70+
; CHECK-NEXT: %y = alloca double, align 8
71+
; CHECK-NEXT: store double %x, double* %x.addr, align 8
72+
; CHECK-NEXT: store double %"x'", double* %"x.addr'ipa", align 8
73+
; CHECK-NEXT: call void @fwddiffesquare_(double* %x.addr, double* %"x.addr'ipa", double* %y, double* %"y'ipa")
74+
; CHECK-NEXT: %0 = load double, double* %"y'ipa", align 8
75+
; CHECK-NEXT: ret double %0
76+
; CHECK-NEXT: }
77+
78+
; CHECK: define internal void @fwddiffesquare_(double* nocapture readonly %src, double* nocapture %"src'", double* nocapture %dest, double* nocapture %"dest'")
79+
; CHECK-NEXT: entry:
80+
; CHECK-NEXT: %0 = load double, double* %src, align 8
81+
; CHECK-NEXT: %1 = load double, double* %"src'", align 8
82+
; CHECK-NEXT: %mul = fmul double %0, %0
83+
; CHECK-NEXT: %2 = fmul fast double %1, %0
84+
; CHECK-NEXT: %3 = fmul fast double %1, %0
85+
; CHECK-NEXT: %4 = fadd fast double %2, %3
86+
; CHECK-NEXT: store double %mul, double* %dest, align 8
87+
; CHECK-NEXT: store double %4, double* %"dest'", align 8
88+
; CHECK-NEXT: ret void
89+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)