Skip to content

Commit ec70616

Browse files
committed
update batch test
1 parent de78744 commit ec70616

File tree

1 file changed

+34
-15
lines changed

1 file changed

+34
-15
lines changed

tests/codegen/autodiffv.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
4+
45
#![feature(autodiff)]
56

67
use std::autodiff::autodiff;
78

8-
#[autodiff(d_square, Reverse, 4, Duplicated, Active)]
9+
#[autodiff(d_square3, Forward, Dual, DualOnly)]
10+
#[no_mangle]
11+
fn squaref(x: &f32) -> f32 {
12+
2.0 * x * x
13+
}
14+
15+
16+
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
17+
#[autodiff(d_square, Forward, 4, Dual, Dual)]
918
#[no_mangle]
10-
fn square(x: &f64) -> f64 {
19+
fn square(x: &f32) -> f32 {
1120
x * x
1221
}
1322

@@ -33,21 +42,31 @@ fn square(x: &f64) -> f64 {
3342
// CHECK-NEXT:}
3443

3544
fn main() {
36-
let x = 3.0;
45+
let x = std::hint::black_box(3.0);
3746
let output = square(&x);
47+
dbg!(&output);
3848
assert_eq!(9.0, output);
49+
dbg!(squaref(&x));
3950

40-
let mut df_dx1 = 0.0;
41-
let mut df_dx2 = 0.0;
42-
let mut df_dx3 = 0.0;
51+
let mut df_dx1 = 1.0;
52+
let mut df_dx2 = 2.0;
53+
let mut df_dx3 = 3.0;
4354
let mut df_dx4 = 0.0;
44-
let [o1, o2, o3, o4] = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4, 1.0);
45-
assert_eq!(output, o1);
46-
assert_eq!(output, o2);
47-
assert_eq!(output, o3);
48-
assert_eq!(output, o4);
49-
assert_eq!(6.0, df_dx1);
50-
assert_eq!(6.0, df_dx2);
51-
assert_eq!(6.0, df_dx3);
52-
assert_eq!(6.0, df_dx4);
55+
let [o1,o2,o3,o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
56+
dbg!(o1, o2, o3, o4);
57+
let [output2, o1,o2,o3,o4] = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
58+
dbg!(o1, o2, o3, o4);
59+
assert_eq!(output, output2);
60+
assert!((6.0 - o1).abs() < 1e-10);
61+
assert!((12.0 - o2).abs() < 1e-10);
62+
assert!((18.0 - o3).abs() < 1e-10);
63+
assert!((0.0 - o4).abs() < 1e-10);
64+
assert_eq!(1.0, df_dx1);
65+
assert_eq!(2.0, df_dx2);
66+
assert_eq!(3.0, df_dx3);
67+
assert_eq!(0.0, df_dx4);
68+
assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
69+
assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
70+
assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
71+
assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
5372
}

0 commit comments

Comments
 (0)