@@ -110,3 +110,51 @@ public struct Foo: Differentiable {
110
110
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooVSfycig : $@convention(method) (Foo) -> Float {
111
111
// CHECK-NEXT: }
112
112
}
113
+
114
+ // Test function that is differentiable wrt subset of its parameters:
115
+ // - wrt x: explicit @differentiable attribute, with no custom derivative specified
116
+ // - wrt y: explicit @differentiable attribute, with custom derivative specified
117
+ // - wrt x, y: custom deriviative specified, with no explicit @differentiable attribute
118
+ // Has a tuple argument to verify that indices are correctly lowered to SIL.
119
+
120
+ @differentiable ( wrt: x)
121
+ @differentiable ( wrt: y)
122
+ public func wrt_subset( _ tup: ( Int , Int ) , _ x: Float , _ y: Float ) -> Float {
123
+ return 0
124
+ }
125
+
126
+ @differentiating ( wrt_subset, wrt: y)
127
+ public func wrt_subset_jvp_wrt_y( _ tup: ( Int , Int ) , _ x: Float , _ y: Float ) -> ( value: Float , differential: ( Float ) -> Float ) {
128
+ return ( 0 , { $0 } )
129
+ }
130
+
131
+ @differentiating ( wrt_subset, wrt: y)
132
+ public func wrt_subset_vjp_wrt_y( _ tup: ( Int , Int ) , _ x: Float , _ y: Float ) -> ( value: Float , pullback: ( Float ) -> Float ) {
133
+ return ( 0 , { $0 } )
134
+ }
135
+
136
+ @differentiating ( wrt_subset)
137
+ public func wrt_subset_jvp_wrt_x_y( _ tup: ( Int , Int ) , _ x: Float , _ y: Float ) -> ( value: Float , differential: ( Float , Float ) -> Float ) {
138
+ return ( 0 , { $0 + $1 } )
139
+ }
140
+
141
+ @differentiating ( wrt_subset)
142
+ public func wrt_subset_vjp_wrt_x_y( _ tup: ( Int , Int ) , _ x: Float , _ y: Float ) -> ( value: Float , pullback: ( Float ) -> ( Float , Float ) ) {
143
+ return ( 0 , { ( $0, $0) } )
144
+ }
145
+
146
+ // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
147
+ // CHECK-NEXT: sil_differentiability_witness [parameters 2 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
148
+ // CHECK-NEXT: jvp:
149
+ // CHECK-NEXT: vjp:
150
+ // CHECK-NEXT: }
151
+
152
+ // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
153
+ // CHECK-NEXT: sil_differentiability_witness [parameters 3] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
154
+ // CHECK-NEXT: jvp:
155
+ // CHECK-NEXT: vjp:
156
+ // CHECK-NEXT: }
157
+
158
+ // CHECK-LABEL: // differentiability witness for wrt_subset(_:_:_:)
159
+ // CHECK-NEXT: sil_differentiability_witness [parameters 2] [results 0] @$s36sil_differentiability_witness_silgen10wrt_subsetySfSi_Sit_S2ftF : $@convention(thin) (Int, Int, Float, Float) -> Float {
160
+ // CHECK-NEXT: }
0 commit comments