Skip to content

Commit 81e87fc

Browse files
committed
Fix shared memory initialization
1 parent f7addf0 commit 81e87fc

File tree

5 files changed

+217
-6
lines changed

5 files changed

+217
-6
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3018,10 +3018,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30183018
} else {
30193019
llvm_unreachable("unknown gpu architecture");
30203020
}
3021-
Value *AndVal = ebuilder.CreateAnd(ebuilder.CreateAnd(tx, ty), tz);
3021+
Value *OrVal = ebuilder.CreateOr(ebuilder.CreateOr(tx, ty), tz);
30223022

30233023
ebuilder.CreateCondBr(
3024-
ebuilder.CreateICmpEQ(AndVal, ConstantInt::get(AndVal->getType(), 0)),
3024+
ebuilder.CreateICmpEQ(OrVal, ConstantInt::get(OrVal->getType(), 0)),
30253025
sharedBlock, OldEntryInsts);
30263026

30273027
IRBuilder<> instbuilder(OldEntryInsts, OldEntryInsts->begin());
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -correlated-propagation -adce -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define double @f(double* nocapture %x, i64 %n) #0 {
5+
entry:
6+
br label %loop
7+
8+
loop:
9+
%j = phi i64 [ %nj, %end ], [ 0, %entry ]
10+
%sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ]
11+
%nj = add nsw nuw i64 %j, 1
12+
%g0 = getelementptr inbounds double, double* %x, i64 %j
13+
br label %body
14+
15+
body: ; preds = %entry, %for.cond.cleanup6
16+
%i = phi i64 [ %next, %body ], [ 0, %loop ]
17+
%gep = getelementptr inbounds double, double* %g0, i64 %i
18+
%ld = load double, double* %gep, align 8
19+
%cmp = fcmp oeq double %ld, 3.141592e+00
20+
%next = add nuw i64 %i, 1
21+
br i1 %cmp, label %body, label %end
22+
23+
end:
24+
%gep2 = getelementptr inbounds double, double* %x, i64 %i
25+
%ld2 = load double, double* %gep2, align 8
26+
%nsum = fadd double %ld2, %sum
27+
%cmp2 = icmp ne i64 %nj, 10
28+
br i1 %cmp2, label %loop, label %exit
29+
30+
exit:
31+
ret double %nsum
32+
}
33+
34+
; Function Attrs: noinline nounwind uwtable
35+
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
36+
entry:
37+
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
38+
ret double %call
39+
}
40+
41+
declare dso_local double @__enzyme_autodiff(i8*, double*, double*, i64) local_unnamed_addr
42+
43+
attributes #0 = { noinline norecurse nounwind uwtable }
44+
attributes #1 = { noinline nounwind uwtable }
45+
46+
; CHECK: define internal void @diffef(double* nocapture %x, double* nocapture %"x'", i64 %n, double %differeturn)
47+
; CHECK-NEXT: entry:
48+
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(80) dereferenceable_or_null(80) i8* @malloc(i64 80)
49+
; CHECK-NEXT: %loopLimit_malloccache = bitcast i8* %malloccall to i64*
50+
; CHECK-NEXT: br label %loop
51+
52+
; CHCEK: loop: ; preds = %end, %entry
53+
; CHCEK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ]
54+
; CHCEK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
55+
; CHCEK-NEXT: %g0 = getelementptr inbounds double, double* %x, i64 %iv
56+
; CHCEK-NEXT: br label %body
57+
58+
; CHECK: body:
59+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ]
60+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
61+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %g0, i64 %iv1
62+
; CHECK-NEXT: %ld = load double, double* %gep, align 8
63+
; CHECK-NEXT: %cmp = fcmp oeq double %ld, 0x400921FAFC8B007A
64+
; CHECK-NEXT: br i1 %cmp, label %body, label %end
65+
66+
; CHECK: end: ; preds = %body
67+
; CHECK-NEXT: %0 = getelementptr inbounds i64, i64* %loopLimit_malloccache, i64 %iv
68+
; CHECK-NEXT: store i64 %iv1, i64* %0, align 8
69+
; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10
70+
; CHECK-NEXT: br i1 %cmp2, label %loop, label %invertend
71+
72+
; CHECK: invertentry: ; preds = %invertloop
73+
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
74+
; CHECK-NEXT: ret void
75+
76+
; CHECK: invertloop: ; preds = %invertbody
77+
; CHECK-NEXT: %[[icmp0:.+]] = icmp eq i64 %"iv'ac.0", 0
78+
; CHECK-NEXT: br i1 %[[icmp0]], label %invertentry, label %incinvertloop
79+
80+
; CHECK: incinvertloop: ; preds = %invertloop
81+
; CHECK-NEXT: %2 = add nsw i64 %"iv'ac.0", -1
82+
; CHECK-NEXT: br label %invertend
83+
84+
; CHECK: invertbody: ; preds = %invertend, %incinvertbody
85+
; CHECK-NEXT: %"iv1'ac.0" = phi i64 [ 0, %invertend ], [ %4, %incinvertbody ]
86+
; CHECK-NEXT: %3 = icmp eq i64 %"iv1'ac.0", 0
87+
; CHECK-NEXT: br i1 %3, label %invertloop, label %incinvertbody
88+
89+
; CHECK: incinvertbody: ; preds = %invertbody
90+
; CHECK-NEXT: %4 = add nsw i64 %"iv1'ac.0", -1
91+
; CHECK-NEXT: br label %invertbody
92+
93+
; CHECK: invertend: ; preds = %end, %incinvertloop
94+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %2, %incinvertloop ], [ 9, %end ]
95+
; CHECK-NEXT: %5 = getelementptr inbounds i64, i64* %loopLimit_malloccache, i64 %"iv'ac.0"
96+
; CHECK-NEXT: %6 = load i64, i64* %5, align 8, !invariant.group !0
97+
; CHECK-NEXT: %"gep2'ipg_unwrap" = getelementptr inbounds double, double* %"x'", i64 %6
98+
; CHECK-NEXT: %7 = load double, double* %"gep2'ipg_unwrap", align 8
99+
; CHECK-NEXT: %8 = fadd fast double %7, %differeturn
100+
; CHECK-NEXT: store double %8, double* %"gep2'ipg_unwrap", align 8
101+
; CHECK-NEXT: br label %invertbody
102+
; CHECK-NEXT: }
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define double @f(double* nocapture %x, i64 %n) #0 {
5+
entry:
6+
br label %loop
7+
8+
loop:
9+
%j = phi i64 [ %nj, %end ], [ 0, %entry ]
10+
%sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ]
11+
%nj = add nsw nuw i64 %j, 1
12+
%g0 = getelementptr inbounds double, double* %x, i64 %j
13+
br label %body
14+
15+
body: ; preds = %entry, %for.cond.cleanup6
16+
%i = phi i64 [ %next, %body ], [ 0, %loop ]
17+
%idx = phi i64 [ %nidx, %body ], [ 0, %loop ]
18+
%gep = getelementptr inbounds double, double* %g0, i64 %i
19+
%ld = load double, double* %gep, align 8
20+
%cmp = fcmp oeq double %ld, 3.141592e+00
21+
%next = add nuw i64 %i, 1
22+
%int = fptoui double %ld to i64
23+
%nidx = add nuw i64 %idx, %int
24+
br i1 %cmp, label %body, label %end
25+
26+
end:
27+
%gep2 = getelementptr inbounds double, double* %x, i64 %idx
28+
%ld2 = load double, double* %gep2, align 8
29+
%nsum = fadd double %ld2, %sum
30+
%cmp2 = icmp ne i64 %nj, 10
31+
br i1 %cmp2, label %loop, label %exit
32+
33+
exit:
34+
ret double %nsum
35+
}
36+
37+
; Function Attrs: noinline nounwind uwtable
38+
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
39+
entry:
40+
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
41+
ret double %call
42+
}
43+
44+
declare dso_local double @__enzyme_autodiff(i8*, double*, double*, i64) local_unnamed_addr
45+
46+
attributes #0 = { noinline norecurse nounwind uwtable }
47+
attributes #1 = { noinline nounwind uwtable }
48+
49+
; CHECK: define internal void @diffef(double* nocapture %x, double* nocapture %"x'", i64 %n, double %differeturn)
50+
; CHECK-NEXT: entry:
51+
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(80) dereferenceable_or_null(80) i8* @malloc(i64 80)
52+
; CHECK-NEXT: %"idx!manual_lcssa_malloccache" = bitcast i8* %malloccall to i64*
53+
; CHECK-NEXT: br label %loop
54+
55+
; CHECK: loop: ; preds = %end, %entry
56+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ]
57+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
58+
; CHECK-NEXT: %g0 = getelementptr inbounds double, double* %x, i64 %iv
59+
; CHECK-NEXT: br label %body
60+
61+
; CHECK: body: ; preds = %body, %loop
62+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ]
63+
; CHECK-NEXT: %idx = phi i64 [ %nidx, %body ], [ 0, %loop ]
64+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
65+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %g0, i64 %iv1
66+
; CHECK-NEXT: %ld = load double, double* %gep, align 8
67+
; CHECK-NEXT: %cmp = fcmp oeq double %ld, 0x400921FAFC8B007A
68+
; CHECK-NEXT: %int = fptoui double %ld to i64
69+
; CHECK-NEXT: %nidx = add nuw i64 %idx, %int
70+
; CHECK-NEXT: br i1 %cmp, label %body, label %end
71+
72+
; CHECK: end: ; preds = %body
73+
; CHECK-NEXT: %0 = getelementptr inbounds i64, i64* %"idx!manual_lcssa_malloccache", i64 %iv
74+
; CHECK-NEXT: store i64 %idx, i64* %0, align 8, !invariant.group !0
75+
; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10
76+
; CHECK-NEXT: br i1 %cmp2, label %loop, label %invertend
77+
78+
; CHECK: invertentry: ; preds = %invertloop
79+
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
80+
; CHECK-NEXT: ret void
81+
82+
; CHECK: invertloop: ; preds = %invertbody
83+
; CHECK-NEXT: %1 = icmp eq i64 %"iv'ac.0", 0
84+
; CHECK-NEXT: %2 = select {{(fast )?}}i1 %1, double 0.000000e+00, double %differeturn
85+
; CHECK-NEXT: br i1 %1, label %invertentry, label %incinvertloop
86+
87+
; CHECK: incinvertloop: ; preds = %invertloop
88+
; CHECK-NEXT: %3 = add nsw i64 %"iv'ac.0", -1
89+
; CHECK-NEXT: br label %invertend
90+
91+
; CHECK: invertbody: ; preds = %invertend, %incinvertbody
92+
; CHECK-NEXT: %"iv1'ac.0" = phi i64 [ 0, %invertend ], [ %5, %incinvertbody ]
93+
; CHECK-NEXT: %4 = icmp eq i64 %"iv1'ac.0", 0
94+
; CHECK-NEXT: br i1 %4, label %invertloop, label %incinvertbody
95+
96+
; CHECK: incinvertbody: ; preds = %invertbody
97+
; CHECK-NEXT: %5 = add nsw i64 %"iv1'ac.0", -1
98+
; CHECK-NEXT: br label %invertbody
99+
100+
; CHECK: invertend: ; preds = %end, %incinvertloop
101+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %3, %incinvertloop ], [ 9, %end ]
102+
; CHECK-NEXT: %6 = getelementptr inbounds i64, i64* %"idx!manual_lcssa_malloccache", i64 %"iv'ac.0"
103+
; CHECK-NEXT: %7 = load i64, i64* %6, align 8, !invariant.group !0
104+
; CHECK-NEXT: %"gep2'ipg_unwrap" = getelementptr inbounds double, double* %"x'", i64 %7
105+
; CHECK-NEXT: %8 = load double, double* %"gep2'ipg_unwrap", align 8
106+
; CHECK-NEXT: %9 = fadd fast double %8, %differeturn
107+
; CHECK-NEXT: store double %9, double* %"gep2'ipg_unwrap", align 8
108+
; CHECK-NEXT: br label %invertbody
109+
; CHECK-NEXT: }

enzyme/test/Enzyme/ReverseMode/sharedcachefwd.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ attributes #6 = { nounwind }
182182
; CHECK-NEXT: %0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
183183
; CHECK-NEXT: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
184184
; CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
185-
; CHECK-NEXT: %3 = and i32 %0, %1
186-
; CHECK-NEXT: %4 = and i32 %3, %2
185+
; CHECK-NEXT: %3 = or i32 %0, %1
186+
; CHECK-NEXT: %4 = or i32 %3, %2
187187
; CHECK-NEXT: %5 = icmp eq i32 %4, 0
188188
; CHECK-NEXT: br i1 %5, label %shblock, label %[[blk:.+]]
189189

enzyme/test/Enzyme/ReverseMode/sharedmem.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ attributes #4 = { nounwind }
9494
; CHECK-NEXT: %0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
9595
; CHECK-NEXT: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
9696
; CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
97-
; CHECK-NEXT: %3 = and i32 %0, %1
98-
; CHECK-NEXT: %4 = and i32 %3, %2
97+
; CHECK-NEXT: %3 = or i32 %0, %1
98+
; CHECK-NEXT: %4 = or i32 %3, %2
9999
; CHECK-NEXT: %5 = icmp eq i32 %4, 0
100100
; CHECK-NEXT: br i1 %5, label %shblock, label %invertbb
101101

0 commit comments

Comments
 (0)