Skip to content

Commit 3880378

Browse files
author
marcrasi
authored
[AutoDiff] differentiability witness wrt params test (#28389)
1 parent 9c79811 commit 3880378

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/AutoDiff/sil_differentiability_witness_silgen.swift

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,51 @@ public struct Foo: Differentiable {
110110
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVSfycig : $@convention(method) (Foo) -> Float {
111111
// CHECK-NEXT: }
112112
}
113+
114+
// Test function that is differentiable wrt subset of its parameters:
115+
// - wrt x: explicit @differentiable attribute, with no custom derivative specified
116+
// - wrt y: explicit @differentiable attribute, with custom derivative specified
117+
// - wrt x, y: custom deriviative specified, with no explicit @differentiable attribute
118+
// Has a tuple argument to verify that indices are correctly lowered to SIL.
119+
120+
@differentiable(wrt: x)
121+
@differentiable(wrt: y)
122+
public func wrt_subset(_ tup: (Int, Int), _ x: Float, _ y: Float) -> Float {
123+
return 0
124+
}
125+
126+
@differentiating(wrt_subset, wrt: y)
127+
public func wrt_subset_jvp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float) -> Float) {
128+
return (0, { $0 })
129+
}
130+
131+
@differentiating(wrt_subset, wrt: y)
132+
public func wrt_subset_vjp_wrt_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> Float) {
133+
return (0, { $0 })
134+
}
135+
136+
@differentiating(wrt_subset)
137+
public func wrt_subset_jvp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, differential: (Float, Float) -> Float) {
138+
return (0, { $0 + $1 })
139+
}
140+
141+
@differentiating(wrt_subset)
142+
public func wrt_subset_vjp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
143+
return (0, { ($0, $0) })
144+
}
145+
146+
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
147+
// CHECK-NEXT: sil_differentiability_witness [parameters 2 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
148+
// CHECK-NEXT: jvp:
149+
// CHECK-NEXT: vjp:
150+
// CHECK-NEXT: }
151+
152+
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
153+
// CHECK-NEXT: sil_differentiability_witness [parameters 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
154+
// CHECK-NEXT: jvp:
155+
// CHECK-NEXT: vjp:
156+
// CHECK-NEXT: }
157+
158+
// CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
159+
// CHECK-NEXT: sil_differentiability_witness [parameters 2] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
160+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)