Skip to content

Commit 3500823

Browse files
committed
Fix unwrap cache
1 parent 6d3b2c4 commit 3500823

File tree

2 files changed

+191
-23
lines changed

2 files changed

+191
-23
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,36 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
125125
return available.lookup(val);
126126
}
127127

128+
if (auto inst = dyn_cast<Instruction>(val)) {
129+
// if (inst->getParent() == &newFunc->getEntryBlock()) {
130+
// return inst;
131+
//}
132+
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
133+
if (BuilderM.GetInsertBlock()->size() &&
134+
BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
135+
if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
136+
// llvm::errs() << "allowed " << *inst << "from domination\n";
137+
assert(inst->getType() == val->getType());
138+
return inst;
139+
}
140+
} else {
141+
if (DT.dominates(inst, BuilderM.GetInsertBlock())) {
142+
// llvm::errs() << "allowed " << *inst << "from block domination\n";
143+
assert(inst->getType() == val->getType());
144+
return inst;
145+
}
146+
}
147+
}
148+
}
149+
128150
if (this->mode == DerivativeMode::ReverseModeGradient &&
129151
mode != UnwrapMode::LegalFullUnwrap) {
130152
// TODO this isOriginal is a bottleneck, the new mapping of
131153
// knownRecompute should be precomputed and maintained to lookup instead
132154
Value *orig = isOriginal(val);
133155
if (orig &&
134156
knownRecomputeHeuristic.find(orig) != knownRecomputeHeuristic.end()) {
135-
if (!knownRecomputeHeuristic[orig]) {
157+
if (!knownRecomputeHeuristic[orig] && !legalRecompute(orig, available, &BuilderM)) {
136158
return nullptr;
137159
}
138160
}
@@ -153,28 +175,6 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
153175
return unwrap_cache[BuilderM.GetInsertBlock()][idx];
154176
}
155177

156-
if (auto inst = dyn_cast<Instruction>(val)) {
157-
// if (inst->getParent() == &newFunc->getEntryBlock()) {
158-
// return inst;
159-
//}
160-
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
161-
if (BuilderM.GetInsertBlock()->size() &&
162-
BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
163-
if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
164-
// llvm::errs() << "allowed " << *inst << "from domination\n";
165-
assert(inst->getType() == val->getType());
166-
return inst;
167-
}
168-
} else {
169-
if (DT.dominates(inst, BuilderM.GetInsertBlock())) {
170-
// llvm::errs() << "allowed " << *inst << "from block domination\n";
171-
assert(inst->getType() == val->getType());
172-
return inst;
173-
}
174-
}
175-
}
176-
}
177-
178178
#define getOpFullest(Builder, vtmp, frominst, check) \
179179
({ \
180180
Value *v = vtmp; \
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -early-cse -adce -S | FileCheck %s
2+
; ModuleID = 'inp.ll'
3+
4+
declare dso_local void @_Z17__enzyme_autodiffPvPdS0_i(i8*, double*, double*, i64*) local_unnamed_addr #4
5+
define dso_local void @outer(double* %m, double* %m2, i64* %n) local_unnamed_addr #2 {
6+
entry:
7+
call void @_Z17__enzyme_autodiffPvPdS0_i(i8* bitcast (double (double*, i64*)* @_Z10reduce_maxPdi to i8*), double* nonnull %m, double* nonnull %m2, i64* %n)
8+
ret void
9+
}
10+
; Function Attrs: nounwind uwtable
11+
define dso_local double @_Z10reduce_maxPdi(double* %vec, i64* %v) #0 {
12+
entry:
13+
%res = call double @pb(double* %vec, i64* %v)
14+
store double 0.000000e+00, double* %vec, align 8
15+
store i64 0, i64* %v, align 8
16+
%bc = bitcast i64* %v to i8*
17+
call void @llvm.memset.p0i8.i64(i8* %bc, i8 0, i64 128, i1 false)
18+
ret double %res
19+
}
20+
21+
declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i1)
22+
23+
define double @pb(double* %x, i64* %sptr) {
24+
entry:
25+
%step = load i64, i64* %sptr, align 8, !tbaa !6
26+
br label %for.body
27+
28+
for.body: ; preds = %for.cond.loopexit, %entry
29+
%i2.0245 = phi i64 [ 0, %entry ], [ %add, %for.cond.loopexit ]
30+
%add = add nsw i64 %i2.0245, 1
31+
br label %for.body59
32+
33+
for.body59: ; preds = %for.body59, %for.body
34+
%k2.0243 = phi i64 [ %add61, %for.body59 ], [ 0, %for.body ]
35+
%add61 = add nsw i64 %k2.0243, %step
36+
call void @inner(double* %x)
37+
%cmp57 = icmp slt i64 %add61, 100
38+
br i1 %cmp57, label %for.body59, label %for.cond.loopexit
39+
40+
for.cond.loopexit: ; preds = %for.body59
41+
%cmp53 = icmp slt i64 %add, 56
42+
br i1 %cmp53, label %for.body, label %_ZN5Eigen8internal28aligned_stack_memory_handlerIdED2Ev.exit
43+
44+
_ZN5Eigen8internal28aligned_stack_memory_handlerIdED2Ev.exit: ; preds = %for.cond.loopexit
45+
ret double 0.000000e+00
46+
}
47+
48+
; Function Attrs: nounwind uwtable
49+
define void @inner(double* %blockA) unnamed_addr #3 align 2 {
50+
entry:
51+
%ld = load double, double* %blockA, align 8
52+
%mul = fmul fast double %ld, %ld
53+
store double %mul, double* %blockA, align 8
54+
ret void
55+
}
56+
57+
!4 = !{!5, i64 1, !"omnipotent char"}
58+
!5 = !{!"Simple C++ TBAA"}
59+
!6 = !{!7, !7, i64 0, i64 8}
60+
!7 = !{!4, i64 8, !"long"}
61+
62+
attributes #0 = { readnone speculatable }
63+
64+
; CHECK: define internal { i64, double* } @augmented_pb(double* %x, double* %"x'", i64* %sptr)
65+
; CHECK-NEXT: entry:
66+
; CHECK-NEXT: %step = load i64, i64* %sptr, align 8, !tbaa !0
67+
; CHECK-NEXT: %[[_unwrap:.+]] = udiv i64 99, %step
68+
; CHECK-NEXT: %[[a0:.+]] = add nuw i64 %[[_unwrap]], 1
69+
; CHECK-NEXT: %[[a1:.+]] = mul nuw nsw i64 %[[a0]], 56
70+
; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %[[a1]], 8
71+
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize)
72+
; CHECK-NEXT: %_augmented_malloccache = bitcast i8* %malloccall to double*
73+
; CHECK-NEXT: br label %for.body
74+
75+
; CHECK: for.body: ; preds = %for.cond.loopexit, %entry
76+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.loopexit ], [ 0, %entry ]
77+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
78+
; CHECK-NEXT: br label %for.body59
79+
80+
; CHECK: for.body59: ; preds = %for.body59, %for.body
81+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body59 ], [ 0, %for.body ]
82+
; CHECK-NEXT: %[[a3:.+]] = mul i64 {{(%iv1, %step|%step, %iv1)}}
83+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
84+
; CHECK-NEXT: %add61 = add nsw i64 %[[a3]], %step
85+
; CHECK-NEXT: %_augmented = call fast double @augmented_inner(double* %x, double* %"x'")
86+
; CHECK-NEXT: %[[a5:.+]] = mul nuw nsw i64 %iv, %[[a0]]
87+
; CHECK-NEXT: %[[a6:.+]] = add nuw nsw i64 %iv1, %[[a5]]
88+
; CHECK-NEXT: %[[a7:.+]] = getelementptr inbounds double, double* %_augmented_malloccache, i64 %[[a6]]
89+
; CHECK-NEXT: store double %_augmented, double* %[[a7:.+]], align 8, !invariant.group !
90+
; CHECK-NEXT: %cmp57 = icmp slt i64 %add61, 100
91+
; CHECK-NEXT: br i1 %cmp57, label %for.body59, label %for.cond.loopexit
92+
93+
; CHECK: for.cond.loopexit: ; preds = %for.body59
94+
; CHECK-NEXT: %cmp53 = icmp ne i64 %iv.next, 56
95+
; CHECK-NEXT: br i1 %cmp53, label %for.body, label %_ZN5Eigen8internal28aligned_stack_memory_handlerIdED2Ev.exit
96+
97+
; CHECK: _ZN5Eigen8internal28aligned_stack_memory_handlerIdED2Ev.exit: ; preds = %for.cond.loopexit
98+
; CHECK-NEXT: %.fca.0.insert = insertvalue { i64, double* } undef, i64 %step, 0
99+
; CHECK-NEXT: %.fca.1.insert = insertvalue { i64, double* } %.fca.0.insert, double* %_augmented_malloccache, 1
100+
; CHECK-NEXT: ret { i64, double* } %.fca.1.insert
101+
; CHECK-NEXT: }
102+
103+
; CHECK: define internal void @diffepb(double* %x, double* %"x'", i64* %sptr, double %differeturn, { i64, double* } %tapeArg)
104+
; CHECK-NEXT: entry:
105+
; CHECK-NEXT: %0 = extractvalue { i64, double* } %tapeArg, 1
106+
; CHECK-NEXT: %step = extractvalue { i64, double* } %tapeArg, 0
107+
; CHECK-NEXT: %[[_unwrap:.+]] = udiv i64 99, %step
108+
; CHECK-NEXT: %[[a1:.+]] = add nuw i64 %[[_unwrap]], 1
109+
; CHECK-NEXT: br label %for.body
110+
111+
; CHECK: for.body: ; preds = %for.cond.loopexit, %entry
112+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.loopexit ], [ 0, %entry ]
113+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
114+
; CHECK-NEXT: br label %for.body59
115+
116+
; CHECK: for.body59: ; preds = %for.body59, %for.body
117+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body59 ], [ 0, %for.body ]
118+
; CHECK-NEXT: %[[a4:.+]] = mul i64 {{(%iv1, %step|%step, %iv1)}}
119+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
120+
; CHECK-NEXT: %add61 = add nsw i64 %[[a4]], %step
121+
; CHECK-NEXT: %cmp57 = icmp slt i64 %add61, 100
122+
; CHECK-NEXT: br i1 %cmp57, label %for.body59, label %for.cond.loopexit
123+
124+
; CHECK: for.cond.loopexit: ; preds = %for.body59
125+
; CHECK-NEXT: %cmp53 = icmp ne i64 %iv.next, 56
126+
; CHECK-NEXT: br i1 %cmp53, label %for.body, label %invertfor.cond.loopexit
127+
128+
; CHECK: invertentry: ; preds = %invertfor.body
129+
; CHECK-NEXT: %[[tofree:.+]] = bitcast double* %0 to i8*
130+
; CHECK-NEXT: tail call void @free(i8* nonnull %[[tofree]])
131+
; CHECK-NEXT: ret void
132+
133+
; CHECK: invertfor.body: ; preds = %invertfor.body59
134+
; CHECK-NEXT: %[[cmpf:.+]] = icmp eq i64 %"iv'ac.0", 0
135+
; CHECK-NEXT: br i1 %[[cmpf]], label %invertentry, label %incinvertfor.body
136+
137+
; CHECK: incinvertfor.body: ; preds = %invertfor.body
138+
; CHECK-NEXT: %[[a12:.+]] = add nsw i64 %"iv'ac.0", -1
139+
; CHECK-NEXT: br label %invertfor.cond.loopexit
140+
141+
; CHECK: invertfor.body59: ; preds = %invertfor.cond.loopexit, %incinvertfor.body59
142+
; CHECK-NEXT: %"iv1'ac.0" = phi i64 [ %[[_unwrap:.+]], %invertfor.cond.loopexit ], [ %[[a15:.+]], %incinvertfor.body59 ]
143+
; CHECK-NEXT: %[[_unwrap5:.+]] = mul nuw nsw i64 %"iv'ac.0", %[[a1]]
144+
; CHECK-NEXT: %[[_unwrap6:.+]] = add nuw nsw i64 %"iv1'ac.0", %[[_unwrap5]]
145+
; CHECK-NEXT: %[[_unwrap7:.+]] = getelementptr inbounds double, double* %0, i64 %[[_unwrap6]]
146+
; TODO make the invariant group here the same in the augmented forward
147+
; CHECK-NEXT: %tapeArg3_unwrap = load double, double* %[[_unwrap7:.+]], align 8, !invariant.group !
148+
; CHECK-NEXT: call void @diffeinner(double* %x, double* %"x'", double %tapeArg3_unwrap)
149+
; CHECK-NEXT: %[[a13:.+]] = icmp eq i64 %"iv1'ac.0", 0
150+
; CHECK-NEXT: br i1 %[[a13]], label %invertfor.body, label %incinvertfor.body59
151+
152+
; CHECK: incinvertfor.body59: ; preds = %invertfor.body59
153+
; CHECK-NEXT: %[[a15]] = add nsw i64 %"iv1'ac.0", -1
154+
; CHECK-NEXT: br label %invertfor.body59
155+
156+
; CHECK: invertfor.cond.loopexit: ; preds = %for.cond.loopexit, %incinvertfor.body
157+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[a12]], %incinvertfor.body ], [ 55, %for.cond.loopexit ]
158+
; CHECK-NEXT: br label %invertfor.body59
159+
; CHECK-NEXT: }
160+
161+
; CHECK: define internal void @diffeinner(double* %blockA, double* %"blockA'", double %ld) unnamed_addr align 2 {
162+
; CHECK-NEXT: entry:
163+
; CHECK-NEXT: %0 = load double, double* %"blockA'", align 8
164+
; CHECK-NEXT: %m0diffeld = fmul fast double %0, %ld
165+
; CHECK-NEXT: %1 = fadd fast double %m0diffeld, %m0diffeld
166+
; CHECK-NEXT: store double %1, double* %"blockA'", align 8
167+
; CHECK-NEXT: ret void
168+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)