1
1
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s
2
2
3
- // Test `differentiable_function_extract` with explicit lowered type.
3
+ // Test `differentiable_function_extract` and
4
+ // `differentiability_witness_function` with explicit lowered type.
4
5
// SIL generated via `%target-sil-opt -loadable-address %s`.
5
6
// Note: SIL serialization/deserialization does not support lowered SIL.
6
7
@@ -27,37 +28,43 @@ struct Large : Differentiable {
27
28
mutating func move(along direction: Large.TangentVector)
28
29
}
29
30
31
+ sil_differentiability_witness [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
32
+
30
33
sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
31
34
sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
32
35
33
36
// CHECK-LABEL: sil @test
34
37
sil @test : $@convention(thin) () -> () {
35
38
bb0:
36
- %0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
37
- %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
38
- %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
39
+ %func = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
40
+ %func_jvpwitness_wrt_012 = differentiability_witness_function [jvp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector)
41
+ %func_vjpwitness_wrt_012 = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
42
+ %func_diff_wrt_012 = differentiable_function [parameters 0 1 2] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large with_derivative {%func_jvpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector), %func_vjpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))}
43
+ %func_vjp_wrt_012 = differentiable_function_extract [vjp] %func_diff_wrt_012 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
39
44
40
- // CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
41
- // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
45
+ // CHECK: [[FUNC_REF:%.*]] = function_ref @examplefunc
46
+ // CHECK: [[DIFF_WRT_012:%.*]] = differentiable_function [parameters 0 1 2] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
47
+ // CHECK: [[VJP_WRT_012:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_012]] : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
42
48
43
- %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
44
- %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
49
+ %func_diff_wrt_0 = differentiable_function [parameters 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
50
+ %func_vjp_wrt_0 = differentiable_function_extract [vjp] %func_diff_wrt_0 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
45
51
46
- // CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
47
- // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
52
+ // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
53
+ // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
48
54
49
- %5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
50
- %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
51
- %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
55
+ %method = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
56
+ %method_diff_wrt_0123 = differentiable_function [parameters 0 1 2] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
57
+ %7 = differentiable_function_extract [vjp] %method_diff_wrt_0123 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
52
58
53
- // CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
54
- // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
59
+ // CHECK: [[METHOD_REF:%.*]] = function_ref @examplemethod
60
+ // CHECK: [[DIFF_WRT_0123:%.*]] = differentiable_function [parameters 0 1 2] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
61
+ // CHECK: [[VJP_WRT_0123:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0123]] : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
55
62
56
- %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
57
- %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
63
+ %method_diff_wrt_0 = differentiable_function [parameters 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
64
+ %method_vjp_wrt_0 = differentiable_function_extract [vjp] %method_diff_wrt_0 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
58
65
59
- // CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
60
- // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
66
+ // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
67
+ // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
61
68
62
69
%10 = tuple ()
63
70
return %10 : $()
0 commit comments