Skip to content

Commit 843b631

Browse files
committed
Add round-trip parsing/printing test.
1 parent 573dd3e commit 843b631

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// RUN: %target-sil-opt %s -module-name=sil_differentiability_witness_parse | %target-sil-opt -module-name=sil_differentiability_witness_parse | %FileCheck %s
2+
3+
// Round-trip parsing and printing test.
4+
5+
sil_stage raw
6+
7+
import Builtin
8+
import Swift
9+
import SwiftShims
10+
11+
@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable)
12+
@_silgen_name("foo")
13+
func foo<T>(_ x: T, _ y: Float) -> T
14+
15+
@_silgen_name("foo_jvp")
16+
func foo_jvp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable
17+
18+
@_silgen_name("foo_vjp")
19+
func foo_vjp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable
20+
21+
// main
22+
sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 {
23+
bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>):
24+
%2 = integer_literal $Builtin.Int32, 0 // user: %3
25+
%3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4
26+
return %3 : $Int32 // id: %4
27+
} // end sil function 'main'
28+
29+
// foo
30+
sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
31+
// %0 // user: %5
32+
// %1 // users: %5, %3
33+
// %2 // user: %4
34+
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
35+
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
36+
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
37+
copy_addr %1 to [initialization] %0 : $*T // id: %5
38+
%6 = tuple () // user: %7
39+
return %6 : $() // id: %7
40+
} // end sil function 'foo'
41+
42+
// foo_jvp
43+
sil hidden [ossa] @foo_jvp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) {
44+
// %0 // user: %5
45+
// %1 // users: %5, %3
46+
// %2 // user: %4
47+
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
48+
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
49+
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
50+
copy_addr %1 to [initialization] %0 : $*T // id: %5
51+
// function_ref closure #1 in foo_jvp<A>(_:_:)
52+
%6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7
53+
%7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8
54+
return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8
55+
} // end sil function 'foo_jvp'
56+
57+
// AD__foo__jvp_src_0_wrt_0_1
58+
sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) {
59+
// %0 // user: %4
60+
// %1 // user: %4
61+
// %2 // user: %4
62+
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float):
63+
// function_ref foo_jvp
64+
%3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4
65+
%4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5
66+
return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5
67+
} // end sil function 'AD__foo__jvp_src_0_wrt_0_1'
68+
69+
// closure #1 in foo_jvp<A>(_:_:)
70+
sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector {
71+
// %0 // user: %5
72+
// %1 // users: %5, %3
73+
// %2 // user: %4
74+
bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float):
75+
debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3
76+
debug_value %2 : $Float, let, name "dy", argno 2 // id: %4
77+
copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5
78+
%6 = tuple () // user: %7
79+
return %6 : $() // id: %7
80+
} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_'
81+
82+
// foo_vjp
83+
sil hidden [ossa] @foo_vjp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) {
84+
// %0 // user: %5
85+
// %1 // users: %5, %3
86+
// %2 // user: %4
87+
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
88+
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
89+
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
90+
copy_addr %1 to [initialization] %0 : $*T // id: %5
91+
// function_ref closure #1 in foo_vjp<A>(_:_:)
92+
%6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7
93+
%7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8
94+
return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8
95+
} // end sil function 'foo_vjp'
96+
97+
// closure #1 in foo_vjp<A>(_:_:)
98+
sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) {
99+
// %0 // user: %3
100+
// %1 // users: %3, %2
101+
bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector):
102+
debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2
103+
copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3
104+
%4 = metatype $@thin Float.Type
105+
%5 = alloc_stack $Float // users: %10, %9, %8
106+
%6 = metatype $@thick Float.Type // user: %8
107+
// function_ref static AdditiveArithmetic<>.zero.getter
108+
%7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8
109+
%8 = apply %7<Float>(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0
110+
%9 = load [trivial] %5 : $*Float // user: %11
111+
dealloc_stack %5 : $*Float // id: %10
112+
return %9 : $Float // id: %11
113+
} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_'
114+
115+
// static AdditiveArithmetic<>.zero.getter
116+
sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0
117+
118+
sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
119+
jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
120+
vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
121+
}
122+
123+
// CHECK: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
124+
// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
125+
// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
126+
// CHECK: }

0 commit comments

Comments
 (0)