@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185
185
static func generic< T: Differentiable > ( _ x: T ) -> T
186
186
}
187
187
188
+ extension StaticMethod {
189
+ static func foo( _ x: Float ) -> Float { x }
190
+ static func generic< T: Differentiable > ( _ x: T ) -> T { x }
191
+ }
192
+
188
193
extension StaticMethod {
189
194
@derivative ( of: foo)
190
195
static func jvpFoo( x: Float ) -> ( value: Float , differential: ( Float ) -> Float )
@@ -215,11 +220,16 @@ extension StaticMethod {
215
220
// Test instance methods.
216
221
217
222
protocol InstanceMethod : Differentiable {
218
- // expected-note @+1 {{'foo' defined here}}
219
223
func foo( _ x: Self ) -> Self
224
+ func generic< T: Differentiable > ( _ x: T ) -> Self
225
+ }
226
+
227
+ extension InstanceMethod {
228
+ // expected-note @+1 {{'foo' defined here}}
229
+ func foo( _ x: Self ) -> Self { x }
220
230
221
231
// expected-note @+1 {{'generic' defined here}}
222
- func generic< T: Differentiable > ( _ x: T ) -> Self
232
+ func generic< T: Differentiable > ( _ x: T ) -> Self { self }
223
233
}
224
234
225
235
extension InstanceMethod {
@@ -536,27 +546,34 @@ extension HasStoredProperty {
536
546
}
537
547
}
538
548
539
- // Test cross-file derivative registration. Currently unsupported.
540
- // TODO(TF-1021 ): Lift this restriction.
549
+ // Test derivative registration for protocol requirements . Currently unsupported.
550
+ // TODO(TF-982 ): Lift this restriction and add proper support .
541
551
542
- extension AdditiveArithmetic where Self: Differentiable {
543
- // expected-error @+1 {{derivative not in the same file as the original function}}
544
- @derivative ( of: + )
545
- static func vjpPlus( x: Self , y: Self ) -> (
546
- value: Self ,
547
- pullback: ( Self . TangentVector ) -> ( Self . TangentVector , Self . TangentVector )
548
- ) {
549
- return ( x + y, { v in ( v, v) } )
552
+ protocol ProtocolRequirementDerivative {
553
+ func requirement( _ x: Float ) -> Float
554
+ }
555
+ extension ProtocolRequirementDerivative {
556
+ // NOTE: the error is misleading because `findAbstractFunctionDecl` in
557
+ // TypeCheckAttr.cpp is not setup to show customized error messages for
558
+ // invalid original function candidates.
559
+ // expected-error @+1 {{could not find function 'requirement' with expected type '<Self where Self : ProtocolRequirementDerivative> (Self) -> (Float) -> Float'}}
560
+ @derivative ( of: requirement)
561
+ func vjpRequirement( _ x: Float ) -> ( value: Float , pullback: ( Float ) -> Float ) {
562
+ fatalError ( )
550
563
}
551
564
}
552
565
553
- extension FloatingPoint where Self: Differentiable , Self == Self . TangentVector {
566
+ // Test cross-file derivative registration. Currently unsupported.
567
+ // TODO(TF-1021): Lift this restriction.
568
+
569
+ extension FloatingPoint where Self: Differentiable {
554
570
// expected-error @+1 {{derivative not in the same file as the original function}}
555
- @derivative ( of: + )
556
- static func vjpPlus( x: Self , y: Self ) -> (
557
- value: Self , pullback: ( Self ) -> ( Self , Self )
571
+ @derivative ( of: rounded)
572
+ func vjpRounded( ) -> (
573
+ value: Self ,
574
+ pullback: ( Self . TangentVector ) -> ( Self . TangentVector )
558
575
) {
559
- return ( x + y , { v in ( v , v ) } )
576
+ fatalError ( )
560
577
}
561
578
}
562
579
0 commit comments