|
| 1 | +// RUN: %target-sil-opt %s -module-name=sil_differentiability_witness_parse | %target-sil-opt -module-name=sil_differentiability_witness_parse | %FileCheck %s |
| 2 | + |
| 3 | +// Round-trip parsing and printing test. |
| 4 | + |
| 5 | +sil_stage raw |
| 6 | + |
| 7 | +import Builtin |
| 8 | +import Swift |
| 9 | +import SwiftShims |
| 10 | + |
| 11 | +@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable) |
| 12 | +@_silgen_name("foo") |
| 13 | +func foo<T>(_ x: T, _ y: Float) -> T |
| 14 | + |
| 15 | +@_silgen_name("foo_jvp") |
| 16 | +func foo_jvp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable |
| 17 | + |
| 18 | +@_silgen_name("foo_vjp") |
| 19 | +func foo_vjp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable |
| 20 | + |
| 21 | +// main |
| 22 | +sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 { |
| 23 | +bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>): |
| 24 | + %2 = integer_literal $Builtin.Int32, 0 // user: %3 |
| 25 | + %3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4 |
| 26 | + return %3 : $Int32 // id: %4 |
| 27 | +} // end sil function 'main' |
| 28 | + |
| 29 | +// foo |
| 30 | +sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T { |
| 31 | +// %0 // user: %5 |
| 32 | +// %1 // users: %5, %3 |
| 33 | +// %2 // user: %4 |
| 34 | +bb0(%0 : $*T, %1 : $*T, %2 : $Float): |
| 35 | + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
| 36 | + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
| 37 | + copy_addr %1 to [initialization] %0 : $*T // id: %5 |
| 38 | + %6 = tuple () // user: %7 |
| 39 | + return %6 : $() // id: %7 |
| 40 | +} // end sil function 'foo' |
| 41 | + |
| 42 | +// foo_jvp |
| 43 | +sil hidden [ossa] @foo_jvp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) { |
| 44 | +// %0 // user: %5 |
| 45 | +// %1 // users: %5, %3 |
| 46 | +// %2 // user: %4 |
| 47 | +bb0(%0 : $*T, %1 : $*T, %2 : $Float): |
| 48 | + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
| 49 | + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
| 50 | + copy_addr %1 to [initialization] %0 : $*T // id: %5 |
| 51 | + // function_ref closure #1 in foo_jvp<A>(_:_:) |
| 52 | + %6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7 |
| 53 | + %7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8 |
| 54 | + return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8 |
| 55 | +} // end sil function 'foo_jvp' |
| 56 | + |
| 57 | +// AD__foo__jvp_src_0_wrt_0_1 |
| 58 | +sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) { |
| 59 | +// %0 // user: %4 |
| 60 | +// %1 // user: %4 |
| 61 | +// %2 // user: %4 |
| 62 | +bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float): |
| 63 | + // function_ref foo_jvp |
| 64 | + %3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4 |
| 65 | + %4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5 |
| 66 | + return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5 |
| 67 | +} // end sil function 'AD__foo__jvp_src_0_wrt_0_1' |
| 68 | + |
| 69 | +// closure #1 in foo_jvp<A>(_:_:) |
| 70 | +sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector { |
| 71 | +// %0 // user: %5 |
| 72 | +// %1 // users: %5, %3 |
| 73 | +// %2 // user: %4 |
| 74 | +bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float): |
| 75 | + debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3 |
| 76 | + debug_value %2 : $Float, let, name "dy", argno 2 // id: %4 |
| 77 | + copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5 |
| 78 | + %6 = tuple () // user: %7 |
| 79 | + return %6 : $() // id: %7 |
| 80 | +} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_' |
| 81 | + |
| 82 | +// foo_vjp |
| 83 | +sil hidden [ossa] @foo_vjp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) { |
| 84 | +// %0 // user: %5 |
| 85 | +// %1 // users: %5, %3 |
| 86 | +// %2 // user: %4 |
| 87 | +bb0(%0 : $*T, %1 : $*T, %2 : $Float): |
| 88 | + debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3 |
| 89 | + debug_value %2 : $Float, let, name "y", argno 2 // id: %4 |
| 90 | + copy_addr %1 to [initialization] %0 : $*T // id: %5 |
| 91 | + // function_ref closure #1 in foo_vjp<A>(_:_:) |
| 92 | + %6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7 |
| 93 | + %7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8 |
| 94 | + return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8 |
| 95 | +} // end sil function 'foo_vjp' |
| 96 | + |
| 97 | +// closure #1 in foo_vjp<A>(_:_:) |
| 98 | +sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) { |
| 99 | +// %0 // user: %3 |
| 100 | +// %1 // users: %3, %2 |
| 101 | +bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector): |
| 102 | + debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2 |
| 103 | + copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3 |
| 104 | + %4 = metatype $@thin Float.Type |
| 105 | + %5 = alloc_stack $Float // users: %10, %9, %8 |
| 106 | + %6 = metatype $@thick Float.Type // user: %8 |
| 107 | + // function_ref static AdditiveArithmetic<>.zero.getter |
| 108 | + %7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8 |
| 109 | + %8 = apply %7<Float>(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 110 | + %9 = load [trivial] %5 : $*Float // user: %11 |
| 111 | + dealloc_stack %5 : $*Float // id: %10 |
| 112 | + return %9 : $Float // id: %11 |
| 113 | +} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_' |
| 114 | + |
| 115 | +// static AdditiveArithmetic<>.zero.getter |
| 116 | +sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 |
| 117 | + |
| 118 | +sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { |
| 119 | + jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
| 120 | + vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
| 121 | +} |
| 122 | + |
| 123 | +// CHECK: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable { |
| 124 | +// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) |
| 125 | +// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) |
| 126 | +// CHECK: } |
0 commit comments