Skip to content

Commit 69209be

Browse files
committed
Add parsing/printing tests, address review feedback.
1 parent 27d7abc commit 69209be

File tree

3 files changed

+43
-123
lines changed

3 files changed

+43
-123
lines changed

lib/ParseSIL/ParseSIL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6809,6 +6809,7 @@ static void convertRequirements(Parser &P, SILFunction *F,
68096809
/// decl-sil-differentiability-witness ::=
68106810
/// 'sil_differentiability_witness'
68116811
/// ('[' 'serialized' ']')?
6812+
/// sil-linkage?
68126813
/// '[' 'parameters' index-subset ']'
68136814
/// '[' 'results' index-subset ']'
68146815
/// ('[' 'where' derivatve-generic-signature-requirements ']')?
@@ -6828,8 +6829,9 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68286829
Optional<SILLinkage> linkage;
68296830
if (parseSILLinkage(linkage, P))
68306831
return true;
6832+
// Default to public linkage.
68316833
if (!linkage)
6832-
linkage = SILLinkage::PublicExternal;
6834+
linkage = SILLinkage::Public;
68336835

68346836
// Parse '[serialized]' flag (optional).
68356837
bool isSerialized = false;

lib/SIL/SILPrinter.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3100,7 +3100,6 @@ void SILDifferentiabilityWitness::print(
31003100
interleave(requirements,
31013101
[&](Requirement req) {
31023102
req.print(OS, subPrinter);
3103-
return;
31043103
},
31053104
[&] { OS << ", "; });
31063105
OS << ']';

test/AutoDiff/sil_differentiability_witness_parse.sil

Lines changed: 40 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -2,144 +2,63 @@
22

33
// Round-trip parsing and printing test.
44

5-
// Swift source code (`-emit-silgen` output is below):
6-
//
7-
// @differentiable(jvp: foo_jvp where T: Differentiable)
8-
// @_silgen_name("foo")
9-
// func foo<T>(_ x: T, _ y: Float) -> T { x }
10-
//
11-
// @_silgen_name("foo_jvp")
12-
// func foo_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) {
13-
// (x, { dx, dy in dx })
14-
// }
15-
//
16-
// @_silgen_name("foo_vjp")
17-
// func foo_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) {
18-
// (x, { ($0, .zero) })
19-
// }
20-
//
21-
// The `sil_differentiability_witness` at the end was manually written.
22-
235
sil_stage raw
246

257
import Builtin
268
import Swift
279
import SwiftShims
2810

29-
@differentiable(wrt: (x, y), jvp: foo_jvp where T : Differentiable)
30-
@_silgen_name("foo")
31-
func foo<T>(_ x: T, _ y: Float) -> T
32-
33-
@_silgen_name("foo_jvp")
34-
func foo_jvp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector, Float) -> T.TangentVector) where T : Differentiable
11+
// Test public non-generic function.
12+
// SIL differentiability witness:
13+
// - Has public linkage (implicit).
14+
// - Has no `where` clause.
3515

36-
@_silgen_name("foo_vjp")
37-
func foo_vjp<T>(_ x: T, _ y: Float) -> (T, (T.TangentVector) -> (T.TangentVector, Float)) where T : Differentiable
16+
sil [ossa] @foo : $@convention(thin) (Float) -> Float
3817

39-
// main
40-
sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>) -> Int32 {
41-
bb0(%0 : $Int32, %1 : $UnsafeMutablePointer<Optional<UnsafeMutablePointer<Int8>>>):
42-
%2 = integer_literal $Builtin.Int32, 0 // user: %3
43-
%3 = struct $Int32 (%2 : $Builtin.Int32) // user: %4
44-
return %3 : $Int32 // id: %4
45-
} // end sil function 'main'
18+
sil @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
4619

47-
// foo
48-
sil hidden [differentiable source 0 wrt 0, 1 jvp @AD__foo__jvp_src_0_wrt_0_1 where T : Differentiable] [ossa] @foo : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
49-
// %0 // user: %5
50-
// %1 // users: %5, %3
51-
// %2 // user: %4
52-
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
53-
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
54-
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
55-
copy_addr %1 to [initialization] %0 : $*T // id: %5
56-
%6 = tuple () // user: %7
57-
return %6 : $() // id: %7
58-
} // end sil function 'foo'
20+
sil @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
5921

60-
// foo_jvp
61-
sil hidden [ossa] @foo_jvp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector) {
62-
// %0 // user: %5
63-
// %1 // users: %5, %3
64-
// %2 // user: %4
65-
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
66-
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
67-
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
68-
copy_addr %1 to [initialization] %0 : $*T // id: %5
69-
// function_ref closure #1 in foo_jvp<A>(_:_:)
70-
%6 = function_ref @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %7
71-
%7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // user: %8
72-
return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector // id: %8
73-
} // end sil function 'foo_jvp'
22+
sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
23+
jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
24+
vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
25+
}
7426

75-
// AD__foo__jvp_src_0_wrt_0_1
76-
sil hidden [transparent] [thunk] [ossa] @AD__foo__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) {
77-
// %0 // user: %4
78-
// %1 // user: %4
79-
// %2 // user: %4
80-
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float):
81-
// function_ref foo_jvp
82-
%3 = function_ref @foo_jvp : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %4
83-
%4 = apply %3<τ_0_0>(%0, %1, %2) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // user: %5
84-
return %4 : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector // id: %5
85-
} // end sil function 'AD__foo__jvp_src_0_wrt_0_1'
27+
// CHECK-LABEL: // differentiability witness for foo
28+
// CHECK: sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
29+
// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
30+
// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
31+
// CHECK: }
8632

87-
// closure #1 in foo_jvp<A>(_:_:)
88-
sil private [ossa] @$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector, Float) -> @out T.TangentVector {
89-
// %0 // user: %5
90-
// %1 // users: %5, %3
91-
// %2 // user: %4
92-
bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector, %2 : $Float):
93-
debug_value_addr %1 : $*T.TangentVector, let, name "dx", argno 1 // id: %3
94-
debug_value %2 : $Float, let, name "dy", argno 2 // id: %4
95-
copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %5
96-
%6 = tuple () // user: %7
97-
return %6 : $() // id: %7
98-
} // end sil function '$s4main7foo_jvpyx_13TangentVectorQzAD_Sftctx_Sfts14DifferentiableRzlFA2D_SftcfU_'
33+
// Test internal generic function.
34+
// SIL differentiability witness:
35+
// - Has hidden linkage.
36+
// - Has `where` clause.
9937

100-
// foo_vjp
101-
sil hidden [ossa] @foo_vjp : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T, Float) -> (@out T, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float)) {
102-
// %0 // user: %5
103-
// %1 // users: %5, %3
104-
// %2 // user: %4
38+
sil hidden [ossa] @generic : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
10539
bb0(%0 : $*T, %1 : $*T, %2 : $Float):
106-
debug_value_addr %1 : $*T, let, name "x", argno 1 // id: %3
107-
debug_value %2 : $Float, let, name "y", argno 2 // id: %4
108-
copy_addr %1 to [initialization] %0 : $*T // id: %5
109-
// function_ref closure #1 in foo_vjp<A>(_:_:)
110-
%6 = function_ref @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %7
111-
%7 = partial_apply [callee_guaranteed] %6<T>() : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // user: %8
112-
return %7 : $@callee_guaranteed (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) // id: %8
113-
} // end sil function 'foo_vjp'
40+
copy_addr %1 to [initialization] %0 : $*T
41+
%void = tuple ()
42+
return %void : $()
43+
}
11444

115-
// closure #1 in foo_vjp<A>(_:_:)
116-
sil private [ossa] @$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_ : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T.TangentVector) -> (@out T.TangentVector, Float) {
117-
// %0 // user: %3
118-
// %1 // users: %3, %2
119-
bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector):
120-
debug_value_addr %1 : $*T.TangentVector, let, name "$0", argno 1 // id: %2
121-
copy_addr %1 to [initialization] %0 : $*T.TangentVector // id: %3
122-
%4 = metatype $@thin Float.Type
123-
%5 = alloc_stack $Float // users: %10, %9, %8
124-
%6 = metatype $@thick Float.Type // user: %8
125-
// function_ref static AdditiveArithmetic<>.zero.getter
126-
%7 = function_ref @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %8
127-
%8 = apply %7<Float>(%5, %6) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0
128-
%9 = load [trivial] %5 : $*Float // user: %11
129-
dealloc_stack %5 : $*Float // id: %10
130-
return %9 : $Float // id: %11
131-
} // end sil function '$s4main7foo_vjpyx_13TangentVectorQz_SftADctx_Sfts14DifferentiableRzlFAD_SftADcfU_'
45+
sil hidden @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) {
46+
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float):
47+
return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector
48+
}
13249

133-
// static AdditiveArithmetic<>.zero.getter
134-
sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0
50+
sil hidden @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) {
51+
bb0(%0 : $*τ_0_0, %1 : $*τ_0_0, %2 : $Float):
52+
return undef : $@callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float) // id: %5
53+
}
13554

136-
sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
137-
jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
138-
vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
55+
sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
56+
jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
57+
vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
13958
}
14059

141-
// CHECK-LABEL: // differentiability witness for foo
142-
// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
143-
// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
144-
// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
60+
// CHECK-LABEL: // differentiability witness for generic
61+
// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
62+
// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
63+
// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
14564
// CHECK: }

0 commit comments

Comments
 (0)