Skip to content

Commit e243a3c

Browse files
committed
Add codegen tests
Note(Sa4dUs): As LLVM-IR opt passes are executed after passing the LLVM-IR to Enzyme, most of the cases have turned out not to be problematic. Anyway, we still test them to prevent any kind of regression.
1 parent f9bb47c commit e243a3c

File tree

1 file changed

+263
-0
lines changed

1 file changed

+263
-0
lines changed
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
//@ revisions: debug release
2+
3+
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
4+
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
5+
//@ no-prefer-dynamic
6+
//@ needs-enzyme
7+
8+
// This only tests the function attribute handling for autodiff.
9+
// Function argument changes are troublesome for Enzyme, so we have to
10+
// ensure that arguments remain the same, or if we change them, be aware
11+
// of the changes to handle it correctly.
12+
13+
#![feature(autodiff)]
14+
15+
use std::autodiff::{autodiff_forward, autodiff_reverse};
16+
17+
// Aggregate of two `f32` fields; `f5` takes it by value and `main` also
// passes one as the second argument of `df5`.
#[derive(Copy, Clone)]
18+
struct Input {
19+
x: f32,
20+
y: f32,
21+
}
22+
23+
// Single-field `f32` struct; used as the type of the `y` field of
// `NestedInput`.
#[derive(Copy, Clone)]
24+
struct Wrapper {
25+
z: f32,
26+
}
27+
28+
// Struct with one `f32` field and one nested `Wrapper` field; `f6` takes it
// by value.
#[derive(Copy, Clone)]
29+
struct NestedInput {
30+
x: f32,
31+
y: Wrapper,
32+
}
33+
34+
// Returns `x * x`. `main` passes this function as the `fn(f32) -> f32`
// argument of `f2` and `df2`.
fn square(x: f32) -> f32 {
35+
x * x
36+
}
37+
38+
// CHECK: ; abi_handling::f1
39+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
40+
// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x)
41+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val)
42+
// Sums the two elements of an array taken by reference. The attribute asks
// for a forward-mode derivative named `df1`; `main` calls it with a second
// `&[f32; 2]` argument.
#[autodiff_forward(df1, Dual, Dual)]
43+
fn f1(x: &[f32; 2]) -> f32 {
44+
x[0] + x[1]
45+
}
46+
47+
// CHECK: ; abi_handling::f2
48+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
49+
// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x)
50+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x)
51+
// Calls the function pointer `f` on `x` and returns the result. The attribute
// asks for a reverse-mode derivative named `df2` with activities
// `Const, Active, Active` (presumably `f`, `x`, and the return value, in that
// order - confirm against the autodiff docs). `main` calls it as
// `df2(square, x, 1.0)`.
#[autodiff_reverse(df2, Const, Active, Active)]
52+
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
53+
f(x)
54+
}
55+
56+
// CHECK: ; abi_handling::f3
57+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
58+
// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %y)
59+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val)
60+
// Multiplies two `f32` values taken by reference; both references share the
// explicit lifetime `'a`. The attribute asks for a forward-mode derivative
// named `df3`, which `main` calls with four reference arguments.
#[autodiff_forward(df3, Dual, Dual, Dual)]
61+
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
62+
*x * *y
63+
}
64+
65+
// CHECK: ; abi_handling::f4
66+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
67+
// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %x.1)
68+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float noundef %x.1)
69+
// Multiplies the two components of a tuple passed by value. The attribute
// asks for a forward-mode derivative named `df4`, which `main` calls with a
// second `(f32, f32)` tuple.
#[autodiff_forward(df4, Dual, Dual)]
70+
fn f4(x: (f32, f32)) -> f32 {
71+
x.0 * x.1
72+
}
73+
74+
// CHECK: ; abi_handling::f5
75+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
76+
// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %i.1)
77+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float noundef %i.1)
78+
// Adds the two fields of an `Input` passed by value. The attribute asks for a
// forward-mode derivative named `df5`, which `main` calls with a second
// `Input`.
#[autodiff_forward(df5, Dual, Dual)]
79+
fn f5(i: Input) -> f32 {
80+
i.x + i.y
81+
}
82+
83+
// CHECK: ; abi_handling::f6
84+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
85+
// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %i.1)
86+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float noundef %i.1)
87+
// Computes `i.x + i.y.z * i.y.z` for a `NestedInput` passed by value, i.e. it
// reads through the nested `Wrapper` field. The attribute asks for a
// forward-mode derivative named `df6`, which `main` calls with a second
// `NestedInput`.
#[autodiff_forward(df6, Dual, Dual)]
88+
fn f6(i: NestedInput) -> f32 {
89+
i.x + i.y.z * i.y.z
90+
}
91+
92+
// df1
93+
// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val)
94+
// release-NEXT: start:
95+
// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val
96+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
97+
// release-NEXT: %1 = insertvalue { float, float } %0, float 1.000000e+00, 1
98+
// release-NEXT: ret { float, float } %1
99+
// release-NEXT: }
100+
101+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x, ptr align 4 %"x'")
102+
// debug-NEXT: start:
103+
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0
104+
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0
105+
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4, !alias.scope !4, !noalias !7
106+
// debug-NEXT: %_2 = load float, ptr %0, align 4, !alias.scope !7, !noalias !4
107+
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1
108+
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1
109+
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4, !alias.scope !4, !noalias !7
110+
// debug-NEXT: %_5 = load float, ptr %1, align 4, !alias.scope !7, !noalias !4
111+
// debug-NEXT: %_0 = fadd float %_2, %_5
112+
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl"
113+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
114+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
115+
// debug-NEXT: ret { float, float } %4
116+
// debug-NEXT: }
117+
118+
// df2
119+
// release: define internal fastcc { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x)
120+
// release-NEXT: invertstart:
121+
// release-NEXT: %_0.i = fmul float %x, %x
122+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0
123+
// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1
124+
// release-NEXT: ret { float, float } %1
125+
// release-NEXT: }
126+
127+
// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x, float %differeturn)
128+
// debug-NEXT: start:
129+
// debug-NEXT: %"x'de" = alloca float, align 4
130+
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4
131+
// debug-NEXT: %toreturn = alloca float, align 4
132+
// debug-NEXT: %_0 = call float %f(float %x) #12
133+
// debug-NEXT: store float %_0, ptr %toreturn, align 4
134+
// debug-NEXT: br label %invertstart
135+
// debug-EMPTY:
136+
// debug-NEXT: invertstart: ; preds = %start
137+
// debug-NEXT: %retreload = load float, ptr %toreturn, align 4
138+
// debug-NEXT: %0 = load float, ptr %"x'de", align 4
139+
// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0
140+
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
141+
// debug-NEXT: ret { float, float } %2
142+
// debug-NEXT: }
143+
144+
// df3
145+
// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val)
146+
// release-NEXT: start:
147+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0
148+
// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1
149+
// release-NEXT: ret { float, float } %1
150+
// release-NEXT: }
151+
152+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'")
153+
// debug-NEXT: start:
154+
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4, !alias.scope !9, !noalias !12
155+
// debug-NEXT: %_3 = load float, ptr %x, align 4, !alias.scope !12, !noalias !9
156+
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4, !alias.scope !14, !noalias !17
157+
// debug-NEXT: %_4 = load float, ptr %y, align 4, !alias.scope !17, !noalias !14
158+
// debug-NEXT: %_0 = fmul float %_3, %_4
159+
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4
160+
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3
161+
// debug-NEXT: %2 = fadd fast float %0, %1
162+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
163+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
164+
// debug-NEXT: ret { float, float } %4
165+
// debug-NEXT: }
166+
167+
// df4
168+
// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float %"x.0'")
169+
// release-NEXT: start:
170+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0
171+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1
172+
// release-NEXT: ret { float, float } %1
173+
// release-NEXT: }
174+
175+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %"x.0'", float %x.1, float %"x.1'")
176+
// debug-NEXT: start:
177+
// debug-NEXT: %_0 = fmul float %x.0, %x.1
178+
// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1
179+
// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0
180+
// debug-NEXT: %2 = fadd fast float %0, %1
181+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
182+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
183+
// debug-NEXT: ret { float, float } %4
184+
// debug-NEXT: }
185+
186+
// df5
187+
// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float %"i.0'")
188+
// release-NEXT: start:
189+
// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00
190+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
191+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1
192+
// release-NEXT: ret { float, float } %1
193+
// release-NEXT: }
194+
195+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %"i.0'", float %i.1, float %"i.1'")
196+
// debug-NEXT: start:
197+
// debug-NEXT: %_0 = fadd float %i.0, %i.1
198+
// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'"
199+
// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0
200+
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
201+
// debug-NEXT: ret { float, float } %2
202+
// debug-NEXT: }
203+
204+
// df6
205+
// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'")
206+
// release-NEXT: start:
207+
// release-NEXT: %_3 = fmul float %i.1, %i.1
208+
// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'"
209+
// release-NEXT: %1 = fmul fast float %0, %i.1
210+
// release-NEXT: %_0 = fadd float %i.0, %_3
211+
// release-NEXT: %2 = fadd fast float %"i.0'", %1
212+
// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
213+
// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
214+
// release-NEXT: ret { float, float } %4
215+
// release-NEXT: }
216+
217+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %"i.0'", float %i.1, float %"i.1'")
218+
// debug-NEXT: start:
219+
// debug-NEXT: %_3 = fmul float %i.1, %i.1
220+
// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1
221+
// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1
222+
// debug-NEXT: %2 = fadd fast float %0, %1
223+
// debug-NEXT: %_0 = fadd float %i.0, %_3
224+
// debug-NEXT: %3 = fadd fast float %"i.0'", %2
225+
// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0
226+
// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1
227+
// debug-NEXT: ret { float, float } %5
228+
// debug-NEXT: }
229+
230+
// Calls every primal function (`f1`..`f6`) and its generated derivative
// (`df1`..`df6`) once and prints each result with `dbg!`. The scalar inputs
// `x`, `y`, `z` and the static `Y` are all produced via `std::hint::black_box`.
fn main() {
231+
let x = std::hint::black_box(2.0);
232+
let y = std::hint::black_box(3.0);
233+
let z = std::hint::black_box(4.0);
234+
static Y: f32 = std::hint::black_box(3.2);
235+
236+
let in_f1 = [x, y];
237+
dbg!(f1(&in_f1));
238+
let res_f1 = df1(&in_f1, &[1.0, 0.0]);
239+
dbg!(res_f1);
240+
241+
dbg!(f2(square, x));
242+
let res_f2 = df2(square, x, 1.0);
243+
dbg!(res_f2);
244+
245+
dbg!(f3(&x, &Y));
246+
// NOTE(review): the optimized-build expectation for `df3` earlier in this
// file returns the constant 0x40099999A0000000 (3.2, the value of `Y`) as the
// derivative, which matches `&Y` being received as the tangent of `x` and
// `&1.0` as `y` - confirm this argument order is intended.
let res_f3 = df3(&x, &Y, &1.0, &0.0);
247+
dbg!(res_f3);
248+
249+
let in_f4 = (x, y);
250+
dbg!(f4(in_f4));
251+
let res_f4 = df4(in_f4, (1.0, 0.0));
252+
dbg!(res_f4);
253+
254+
let in_f5 = Input { x, y };
255+
dbg!(f5(in_f5));
256+
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
257+
dbg!(res_f5);
258+
259+
let in_f6 = NestedInput { x, y: Wrapper { z: y } };
260+
dbg!(f6(in_f6));
// Unlike the calls above, the second argument here is built from the
// `black_box`ed locals `x` and `z` instead of literal 1.0 / 0.0 values.
261+
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
262+
dbg!(res_f6);
263+
}

0 commit comments

Comments
 (0)