Skip to content

Commit d93efc1

Browse files
authored
[AutoDiff] Add differentiability witness SILGen test. (#27717)
1 parent 76729c4 commit d93efc1

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %target-sil-opt | %FileCheck %s
2+
3+
// Test SIL differentiability witness SIL generation.
4+
5+
// Test public non-generic function.
6+
// SIL differentiability witness:
7+
// - Has public linkage (implicit).
8+
// - Has no `where` clause.
9+
10+
public func foo(_ x: Float) -> Float { x }
11+
12+
@differentiating(foo)
13+
public func foo_jvp(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
14+
(x, { $0 })
15+
}
16+
17+
@differentiating(foo)
18+
public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
19+
(x, { $0 })
20+
}
21+
22+
// CHECK-LABEL: // differentiability witness for foo(_:)
23+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3fooyS2fF : $@convention(thin) (Float) -> Float {
24+
// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
25+
// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
26+
// CHECK-NEXT: }
27+
28+
// Test internal non-generic function.
29+
// SIL differentiability witness:
30+
// - Has hidden linkage.
31+
// - Has no `where` clause.
32+
// - Has only VJP.
33+
34+
func bar<T>(_ x: Float, _ y: T) -> Float { x }
35+
36+
@differentiating(bar)
37+
public func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
38+
(x, { $0 })
39+
}
40+
41+
// CHECK-LABEL: // differentiability witness for bar<A>(_:_:)
42+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3baryS2f_xtlF : $@convention(thin) <T> (Float, @in_guaranteed T) -> Float {
43+
// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3baryS2f_xtlF__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float)
44+
// CHECK-NEXT: }
45+
46+
// Test internal generic function.
47+
// SIL differentiability witness:
48+
// - Has hidden linkage.
49+
// - Has `where` clause.
50+
51+
@differentiable(where T: Differentiable)
52+
func generic<T>(_ x: T, _ y: Float) -> T { x }
53+
54+
@differentiating(generic)
55+
func generic_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (
56+
value: T, differential: (T.TangentVector, Float) -> T.TangentVector
57+
) {
58+
(x, { dx, dy in dx })
59+
}
60+
61+
@differentiating(generic)
62+
func generic_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (
63+
value: T, pullback: (T.TangentVector) -> (T.TangentVector, Float)
64+
) {
65+
(x, { ($0, .zero) })
66+
}
67+
68+
// CHECK-LABEL: // differentiability witness for generic<A>(_:_:)
69+
// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where T : _Differentiable] @$s36sil_differentiability_witness_silgen7genericyxx_SftlF : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T {
70+
// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
71+
// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
72+
// CHECK-NEXT: }
73+
74+
public struct Foo: Differentiable {
75+
public var x: Float
76+
77+
@differentiable
78+
public init(_ x: Float) {
79+
self.x = x
80+
}
81+
82+
// CHECK-LABEL: // differentiability witness for Foo.init(_:)
83+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVyACSfcfC : $@convention(method) (Float, @thin Foo.Type) -> Foo {
84+
// CHECK-NEXT: }
85+
86+
@differentiable
87+
public func method() -> Float {
88+
x
89+
}
90+
91+
// CHECK-LABEL: // differentiability witness for Foo.method()
92+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV6methodSfyF : $@convention(method) (Foo) -> Float {
93+
// CHECK-NEXT: }
94+
95+
@differentiable
96+
public var computedProperty: Float {
97+
x
98+
}
99+
100+
// CHECK-LABEL: // differentiability witness for Foo.computedProperty.getter
101+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV16computedPropertySfvg : $@convention(method) (Foo) -> Float {
102+
// CHECK-NEXT: }
103+
104+
@differentiable
105+
public subscript() -> Float {
106+
x
107+
}
108+
109+
// CHECK-LABEL: // differentiability witness for Foo.subscript.getter
110+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVSfycig : $@convention(method) (Foo) -> Float {
111+
// CHECK-NEXT: }
112+
}

0 commit comments

Comments
 (0)