@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185
185
static func generic< T: Differentiable > ( _ x: T ) -> T
186
186
}
187
187
188
+ extension StaticMethod {
189
+ static func foo( _ x: Float ) -> Float { x }
190
+ static func generic< T: Differentiable > ( _ x: T ) -> T { x }
191
+ }
192
+
188
193
extension StaticMethod {
189
194
@derivative ( of: foo)
190
195
static func jvpFoo( x: Float ) -> ( value: Float , differential: ( Float ) -> Float )
@@ -215,11 +220,16 @@ extension StaticMethod {
215
220
// Test instance methods.
216
221
217
222
protocol InstanceMethod : Differentiable {
218
- // expected-note @+1 {{'foo' defined here}}
219
223
func foo( _ x: Self ) -> Self
224
+ func generic< T: Differentiable > ( _ x: T ) -> Self
225
+ }
226
+
227
+ extension InstanceMethod {
228
+ // expected-note @+1 {{'foo' defined here}}
229
+ func foo( _ x: Self ) -> Self { x }
220
230
221
231
// expected-note @+1 {{'generic' defined here}}
222
- func generic< T: Differentiable > ( _ x: T ) -> Self
232
+ func generic< T: Differentiable > ( _ x: T ) -> Self { self }
223
233
}
224
234
225
235
extension InstanceMethod {
@@ -539,24 +549,14 @@ extension HasStoredProperty {
539
549
// Test cross-file derivative registration. Currently unsupported.
540
550
// TODO(TF-1021): Lift this restriction.
541
551
542
- extension AdditiveArithmetic where Self: Differentiable {
552
+ extension FloatingPoint where Self: Differentiable {
543
553
// expected-error @+1 {{derivative not in the same file as the original function}}
544
- @derivative ( of: + )
545
- static func vjpPlus ( x : Self , y : Self ) -> (
554
+ @derivative ( of: rounded )
555
+ func vjpRounded ( ) -> (
546
556
value: Self ,
547
- pullback: ( Self . TangentVector ) -> ( Self . TangentVector , Self . TangentVector )
548
- ) {
549
- return ( x + y, { v in ( v, v) } )
550
- }
551
- }
552
-
553
- extension FloatingPoint where Self: Differentiable , Self == Self . TangentVector {
554
- // expected-error @+1 {{derivative not in the same file as the original function}}
555
- @derivative ( of: + )
556
- static func vjpPlus( x: Self , y: Self ) -> (
557
- value: Self , pullback: ( Self ) -> ( Self , Self )
557
+ pullback: ( Self . TangentVector ) -> ( Self . TangentVector )
558
558
) {
559
- return ( x + y , { v in ( v , v ) } )
559
+ fatalError ( )
560
560
}
561
561
}
562
562
0 commit comments