|
2 | 2 |
|
3 | 3 | // Round-trip parsing and printing test.
|
4 | 4 |
|
5 |
| -// Swift source code (`-emit-silgen` output is below): |
6 |
| -// |
7 |
| -// @differentiable(jvp: foo_jvp where T: Differentiable) |
8 |
| -// @_silgen_name("foo") |
9 |
| -// func foo<T>(_ x: T, _ y: Float) -> T { x } |
10 |
| -// |
11 |
| -// @_silgen_name("foo_jvp") |
12 |
| -// func foo_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) { |
13 |
| -// (x, { dx, dy in dx }) |
14 |
| -// } |
15 |
| -// |
16 |
| -// @_silgen_name("foo_vjp") |
17 |
| -// func foo_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) { |
18 |
| -// (x, { ($0, .zero) }) |
19 |
| -// } |
20 |
| -// |
21 |
| -// The `sil_differentiability_witness` at the end was manually written. |
22 |
| - |
23 | 5 | sil_stage raw
|
24 | 6 |
|
25 | 7 | import Builtin
|
26 | 8 | import Swift
|
27 | 9 | import SwiftShims
|
28 | 10 |
|
29 |
| -@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable) |
30 |
| -@_silgen_name("foo") |
31 |
| -func foo<T>(_ x: T, _ y: Float) -> T |
32 |
| - |
33 |
| -@_silgen_name("foo_jvp") |
34 |
| -func foo_jvp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable |
| 11 | +// Test public non-generic function. |
| 12 | +// SIL differentiability witness: |
| 13 | +// - Has public linkage (implicit). |
| 14 | +// - Has no `where` clause. |
35 | 15 |
|
36 |
| -@_silgen_name("foo_vjp") |
37 |
| -func foo_vjp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable |
| 16 | +sil [ossa] @foo : $@convention(thin) (Float) -> Float |
38 | 17 |
|
39 |
| -// main |
40 |
| -sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 { |
41 |
| -bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>): |
42 |
| - %2 = integer_literal $Builtin.Int32, 0 // user: %3 |
43 |
| - %3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4 |
44 |
| - return %3 : $Int32 // id: %4 |
45 |
| -} // end sil function 'main' |
| 18 | +sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
46 | 19 |
|
47 |
| -// foo |
48 |
| -sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T { |
49 |
| -// %0 // user: %5 |
50 |
| -// %1 // users: %5, %3 |
51 |
| -// %2 // user: %4 |
52 |
| -bb0(%0 : $*T, %1 : $*T, %2 : $Float): |
53 |
| - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
54 |
| - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
55 |
| - copy_addr %1 to [initialization] %0 : $*T // id: %5 |
56 |
| - %6 = tuple () // user: %7 |
57 |
| - return %6 : $() // id: %7 |
58 |
| -} // end sil function 'foo' |
| 20 | +sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
59 | 21 |
|
60 |
| -// foo_jvp |
61 |
| -sil hidden [ossa] @foo_jvp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) { |
62 |
| -// %0 // user: %5 |
63 |
| -// %1 // users: %5, %3 |
64 |
| -// %2 // user: %4 |
65 |
| -bb0(%0 : $*T, %1 : $*T, %2 : $Float): |
66 |
| - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
67 |
| - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
68 |
| - copy_addr %1 to [initialization] %0 : $*T // id: %5 |
69 |
| - // function_ref closure #1 in foo_jvp<A>(_:_:) |
70 |
| - %6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7 |
71 |
| - %7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8 |
72 |
| - return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8 |
73 |
| -} // end sil function 'foo_jvp' |
| 22 | +sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { |
| 23 | + jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 24 | + vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 25 | +} |
74 | 26 |
|
75 |
| -// AD__foo__jvp_src_0_wrt_0_1 |
76 |
| -sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { |
77 |
| -// %0 // user: %4 |
78 |
| -// %1 // user: %4 |
79 |
| -// %2 // user: %4 |
80 |
| -bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): |
81 |
| - // function_ref foo_jvp |
82 |
| - %3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4 |
83 |
| - %4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5 |
84 |
| - return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5 |
85 |
| -} // end sil function 'AD__foo__jvp_src_0_wrt_0_1' |
| 27 | +// CHECK-LABEL: // differentiability witness for foo |
| 28 | +// CHECK: sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float { |
| 29 | +// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 30 | +// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| 31 | +// CHECK: } |
86 | 32 |
|
87 |
| -// closure #1 in foo_jvp<A>(_:_:) |
88 |
| -sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector { |
89 |
| -// %0 // user: %5 |
90 |
| -// %1 // users: %5, %3 |
91 |
| -// %2 // user: %4 |
92 |
| -bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float): |
93 |
| - debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3 |
94 |
| - debug_value %2 : $Float, let, name "dy", argno 2 // id: %4 |
95 |
| - copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5 |
96 |
| - %6 = tuple () // user: %7 |
97 |
| - return %6 : $() // id: %7 |
98 |
| -} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_' |
| 33 | +// Test internal generic function. |
| 34 | +// SIL differentiability witness: |
| 35 | +// - Has hidden linkage. |
| 36 | +// - Has `where` clause. |
99 | 37 |
|
100 |
| -// foo_vjp |
101 |
| -sil hidden [ossa] @foo_vjp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) { |
102 |
| -// %0 // user: %5 |
103 |
| -// %1 // users: %5, %3 |
104 |
| -// %2 // user: %4 |
| 38 | +sil hidden [ossa] @generic : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T { |
105 | 39 | bb0(%0 : $*T, %1 : $*T, %2 : $Float):
|
106 |
| - debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
107 |
| - debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
108 |
| - copy_addr %1 to [initialization] %0 : $*T // id: %5 |
109 |
| - // function_ref closure #1 in foo_vjp<A>(_:_:) |
110 |
| - %6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7 |
111 |
| - %7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8 |
112 |
| - return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8 |
113 |
| -} // end sil function 'foo_vjp' |
| 40 | + copy_addr %1 to [initialization] %0 : $*T |
| 41 | + %void = tuple () |
| 42 | + return %void : $() |
| 43 | +} |
114 | 44 |
|
115 |
| -// closure #1 in foo_vjp<A>(_:_:) |
116 |
| -sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) { |
117 |
| -// %0 // user: %3 |
118 |
| -// %1 // users: %3, %2 |
119 |
| -bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector): |
120 |
| - debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2 |
121 |
| - copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3 |
122 |
| - %4 = metatype $@thin Float.Type |
123 |
| - %5 = alloc_stack $Float // users: %10, %9, %8 |
124 |
| - %6 = metatype $@thick Float.Type // user: %8 |
125 |
| - // function_ref static AdditiveArithmetic<>.zero.getter |
126 |
| - %7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8 |
127 |
| - %8 = apply %7<Float>(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 |
128 |
| - %9 = load [trivial] %5 : $*Float // user: %11 |
129 |
| - dealloc_stack %5 : $*Float // id: %10 |
130 |
| - return %9 : $Float // id: %11 |
131 |
| -} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_' |
| 45 | +sil hidden @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { |
| 46 | +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): |
| 47 | + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector |
| 48 | +} |
132 | 49 |
|
133 |
| -// static AdditiveArithmetic<>.zero.getter |
134 |
| -sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 50 | +sil hidden @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) { |
| 51 | +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): |
| 52 | + return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // id: %5 |
| 53 | +} |
135 | 54 |
|
136 |
| -sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { |
137 |
| - jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
138 |
| - vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
| 55 | +sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { |
| 56 | + jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
| 57 | + vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
139 | 58 | }
|
140 | 59 |
|
141 |
| -// CHECK-LABEL: // differentiability witness for foo |
142 |
| -// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { |
143 |
| -// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
144 |
| -// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
| 60 | +// CHECK-LABEL: // differentiability witness for generic |
| 61 | +// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { |
| 62 | +// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
| 63 | +// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
145 | 64 | // CHECK: }
|
0 commit comments