Skip to content

Commit 74d971f

Browse files
authored
ARM Bugfixes (rust-lang#380)
* handle insertvalue array * add tests * fix test
1 parent 4c9a761 commit 74d971f

File tree

5 files changed

+127
-12
lines changed

5 files changed

+127
-12
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,10 +1440,15 @@ class AdjointGenerator
14401440
if (Mode == DerivativeMode::ReverseModePrimal)
14411441
return;
14421442

1443-
auto st = cast<StructType>(IVI.getType());
14441443
bool hasNonPointer = false;
1445-
for (unsigned i = 0; i < st->getNumElements(); ++i) {
1446-
if (!st->getElementType(i)->isPointerTy()) {
1444+
if (auto st = dyn_cast<StructType>(IVI.getType())) {
1445+
for (unsigned i = 0; i < st->getNumElements(); ++i) {
1446+
if (!st->getElementType(i)->isPointerTy()) {
1447+
hasNonPointer = true;
1448+
}
1449+
}
1450+
} else if (auto at = dyn_cast<ArrayType>(IVI.getType())) {
1451+
if (!at->getElementType()->isPointerTy()) {
14471452
hasNonPointer = true;
14481453
}
14491454
}

enzyme/Enzyme/GradientUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,25 @@ class DiffeGradientUtils : public GradientUtils {
17551755
}
17561756
}
17571757
return addedSelects;
1758+
} else if (auto at = dyn_cast<ArrayType>(old->getType())) {
1759+
assert(!mask);
1760+
if (mask)
1761+
llvm_unreachable("cannot handle recursive addToDiffe with mask");
1762+
if (at->getElementType()->isPointerTy())
1763+
return addedSelects;
1764+
for (unsigned i = 0; i < at->getNumElements(); ++i) {
1765+
// TODO pass in full type tree here and recurse into tree.
1766+
Value *v = ConstantInt::get(Type::getInt32Ty(at->getContext()), i);
1767+
SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
1768+
idx2.push_back(v);
1769+
auto selects = addToDiffe(
1770+
val, BuilderM.CreateExtractValue(dif, ArrayRef<unsigned>(i)),
1771+
BuilderM, nullptr, idx2);
1772+
for (auto select : selects) {
1773+
addedSelects.push_back(select);
1774+
}
1775+
}
1776+
return addedSelects;
17581777
} else {
17591778
llvm_unreachable("unknown type to add to diffe");
17601779
exit(1);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
%agg1 = insertvalue [3 x double] undef, double %x, 0
7+
%mul = fmul double %x, %x
8+
%agg2 = insertvalue [3 x double] %agg1, double %mul, 1
9+
%add = fadd double %mul, 2.0
10+
%agg3 = insertvalue [3 x double] %agg2, double %add, 2
11+
%res = extractvalue [3 x double] %agg2, 1
12+
ret double %res
13+
}
14+
15+
define double @test_derivative(double %x) {
16+
entry:
17+
%0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0)
18+
ret double %0
19+
}
20+
21+
; Function Attrs: nounwind
22+
declare double @__enzyme_fwddiff(double (double)*, ...)
23+
24+
; CHECK: define internal double @fwddiffetester(double %x, double %"x'")
25+
; CHECK-NEXT: entry:
26+
; CHECK-NEXT: %0 = fmul fast double %"x'", %x
27+
; CHECK-NEXT: %1 = fmul fast double %"x'", %x
28+
; CHECK-NEXT: %2 = fadd fast double %0, %1
29+
; CHECK-NEXT: ret double %2
30+
; CHECK-NEXT: }
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
%agg1 = insertvalue [3 x double] undef, double %x, 0
7+
%mul = fmul double %x, %x
8+
%agg2 = insertvalue [3 x double] %agg1, double %mul, 1
9+
%add = fadd double %mul, 2.0
10+
%agg3 = insertvalue [3 x double] %agg2, double %add, 2
11+
%res = extractvalue [3 x double] %agg2, 1
12+
ret double %res
13+
}
14+
15+
define double @test_derivative(double %x) {
16+
entry:
17+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
18+
ret double %0
19+
}
20+
21+
; Function Attrs: nounwind
22+
declare double @__enzyme_autodiff(double (double)*, ...)
23+
24+
; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
25+
; CHECK-NEXT: entry:
26+
; CHECK-NEXT: %"agg2'de" = alloca [3 x double], align 8
27+
; CHECK-NEXT: store [3 x double] zeroinitializer, [3 x double]* %"agg2'de"
28+
; CHECK-NEXT: %"agg1'de" = alloca [3 x double], align 8
29+
; CHECK-NEXT: store [3 x double] zeroinitializer, [3 x double]* %"agg1'de"
30+
; CHECK-NEXT: %0 = getelementptr inbounds [3 x double], [3 x double]* %"agg2'de", i32 0, i32 1
31+
; CHECK-NEXT: %1 = load double, double* %0
32+
; CHECK-NEXT: %2 = fadd fast double %1, %differeturn
33+
; CHECK-NEXT: store double %2, double* %0
34+
; CHECK-NEXT: %3 = load [3 x double], [3 x double]* %"agg2'de"
35+
; CHECK-NEXT: %4 = extractvalue [3 x double] %3, 1
36+
; CHECK-NEXT: %5 = load [3 x double], [3 x double]* %"agg2'de"
37+
; CHECK-NEXT: %6 = insertvalue [3 x double] %5, double 0.000000e+00, 1
38+
; CHECK-NEXT: %7 = extractvalue [3 x double] %6, 0
39+
; CHECK-NEXT: %8 = getelementptr inbounds [3 x double], [3 x double]* %"agg1'de", i32 0, i32 0
40+
; CHECK-NEXT: %9 = load double, double* %8
41+
; CHECK-NEXT: %10 = fadd fast double %9, %7
42+
; CHECK-NEXT: store double %10, double* %8
43+
; CHECK-NEXT: %11 = getelementptr inbounds [3 x double], [3 x double]* %"agg1'de", i32 0, i32 1
44+
; CHECK-NEXT: %12 = load double, double* %11
45+
; CHECK-NEXT: store double %12, double* %11
46+
; CHECK-NEXT: %13 = extractvalue [3 x double] %6, 2
47+
; CHECK-NEXT: %14 = getelementptr inbounds [3 x double], [3 x double]* %"agg1'de", i32 0, i32 2
48+
; CHECK-NEXT: %15 = load double, double* %14
49+
; CHECK-NEXT: %16 = fadd fast double %15, %13
50+
; CHECK-NEXT: store double %16, double* %14
51+
; CHECK-NEXT: store [3 x double] zeroinitializer, [3 x double]* %"agg2'de"
52+
; CHECK-NEXT: %m0diffex = fmul fast double %4, %x
53+
; CHECK-NEXT: %m1diffex = fmul fast double %4, %x
54+
; CHECK-NEXT: %17 = fadd fast double %m0diffex, %m1diffex
55+
; CHECK-NEXT: %18 = load [3 x double], [3 x double]* %"agg1'de"
56+
; CHECK-NEXT: %19 = extractvalue [3 x double] %18, 0
57+
; CHECK-NEXT: %20 = fadd fast double %17, %19
58+
; CHECK-NEXT: store [3 x double] zeroinitializer, [3 x double]* %"agg1'de"
59+
; CHECK-NEXT: %21 = insertvalue { double } undef, double %20, 0
60+
; CHECK-NEXT: ret { double } %21
61+
; CHECK-NEXT: }

enzyme/test/Enzyme/ReverseMode/sret1.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-
4747
; CHECK: define {{(dso_local )?}}void @_Z5dtestddd(%struct.Diffe* noalias sret %agg.result, double %x, double %y, double %z)
4848
; CHECK-NEXT: entry:
4949
; CHECK-NEXT: %0 = call { double, double, double } @diffe_Z4testddd(double %x, double %y, double %z, double 1.000000e+00)
50-
; CHECK-NEXT: %1 = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 0
51-
; CHECK-NEXT: %2 = extractvalue { double, double, double } %0, 0
52-
; CHECK-NEXT: store double %2, double* %1
53-
; CHECK-NEXT: %3 = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 1
54-
; CHECK-NEXT: %4 = extractvalue { double, double, double } %0, 1
55-
; CHECK-NEXT: store double %4, double* %3
56-
; CHECK-NEXT: %5 = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 2
57-
; CHECK-NEXT: %6 = extractvalue { double, double, double } %0, 2
58-
; CHECK-NEXT: store double %6, double* %5
50+
; CHECK-DAG: %[[a0:.+]] = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 0
51+
; CHECK-DAG: %[[a1:.+]] = extractvalue { double, double, double } %0, 0
52+
; CHECK-NEXT: store double %[[a1]], double* %[[a0]]
53+
; CHECK-DAG: %[[b0:.+]] = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 1
54+
; CHECK-DAG: %[[b1:.+]] = extractvalue { double, double, double } %0, 1
55+
; CHECK-NEXT: store double %[[b1]], double* %[[b0]]
56+
; CHECK-DAG: %[[c0:.+]] = getelementptr inbounds %struct.Diffe, %struct.Diffe* %agg.result, i32 0, i32 2
57+
; CHECK-DAG: %[[c1:.+]] = extractvalue { double, double, double } %0, 2
58+
; CHECK-NEXT: store double %[[c1]], double* %[[c0]]
5959
; CHECK-NEXT: ret void
6060
; CHECK-NEXT: }
6161

0 commit comments

Comments
 (0)