@@ -10,6 +10,35 @@ sil_stage raw
10
10
import Swift
11
11
import Builtin
12
12
13
+ sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
14
+ sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float
15
+
16
+ // CHECK-LABEL: sil @test
17
+ sil @test : $@convention(thin) () -> () {
18
+ bb0:
19
+ %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
20
+ %1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float
21
+
22
+ // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
23
+ %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
24
+ %3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float
25
+
26
+ // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
27
+ %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
28
+ %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float
29
+ %6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float
30
+
31
+ // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
32
+ %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
33
+ %8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float
34
+
35
+ // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
36
+ %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
37
+
38
+ %ret = tuple ()
39
+ return %ret : $()
40
+ }
41
+
13
42
// The adjoint function emitted by the compiler. Parameter are a vector, as in
14
43
// vector-Jacobian products, and pullback values. The function is partially
15
44
// applied to a pullback struct to form a pullback, which takes a vector and
0 commit comments