@@ -612,7 +612,7 @@ func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
612
612
return x
613
613
}
614
614
615
- protocol DifferentiableAttrRequirements : Differentiable {
615
+ protocol ProtocolRequirements : Differentiable {
616
616
// expected-note @+2 {{protocol requires initializer 'init(x:y:)' with type '(x: Float, y: Float)'}}
617
617
@differentiable
618
618
init ( x: Float , y: Float )
@@ -638,8 +638,13 @@ protocol DifferentiableAttrRequirements : Differentiable {
638
638
func f2( _ x: Float , _ y: Float ) -> Float
639
639
}
640
640
641
- // expected-error @+1 {{does not conform to protocol 'DifferentiableAttrRequirements'}}
642
- struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
641
+ protocol ProtocolRequirementsRefined : ProtocolRequirements {
642
+ // expected-note @+1 {{candidate is missing attribute '@differentiable'}}
643
+ func f1( _ x: Float ) -> Float
644
+ }
645
+
646
+ // expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}}
647
+ struct DiffAttrConformanceErrors : ProtocolRequirements {
643
648
var x : Float
644
649
var y : Float
645
650
@@ -683,6 +688,31 @@ struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
683
688
}
684
689
}
685
690
691
+ protocol ProtocolRequirementsWithDefault_NoConformingTypes {
692
+ @differentiable
693
+ func f1( _ x: Float ) -> Float
694
+ }
695
+ extension ProtocolRequirementsWithDefault_NoConformingTypes {
696
+ // TODO(TF-650): It would be nice to diagnose protocol default implementation
697
+ // with missing `@differentiable` attribute.
698
+ func f1( _ x: Float ) -> Float { x }
699
+ }
700
+
701
+ protocol ProtocolRequirementsWithDefault {
702
+ // expected-note @+2 {{protocol requires function 'f1'}}
703
+ @differentiable
704
+ func f1( _ x: Float ) -> Float
705
+ }
706
+ extension ProtocolRequirementsWithDefault {
707
+ // expected-note @+1 {{candidate is missing attribute '@differentiable'}}
708
+ func f1( _ x: Float ) -> Float { x }
709
+ }
710
+ // expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}}
711
+ struct DiffAttrConformanceErrors2 : ProtocolRequirementsWithDefault {
712
+ // expected-note @+1 {{candidate is missing attribute '@differentiable'}}
713
+ func f1( _ x: Float ) -> Float { x }
714
+ }
715
+
686
716
protocol NotRefiningDiffable {
687
717
@differentiable ( wrt: x)
688
718
// expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}}
@@ -846,17 +876,7 @@ func inout2(x: Float, y: inout Float) -> Float {
846
876
let _ = x + y
847
877
}
848
878
849
-
850
- // Missing `@differentiable` attribute, without printing the 'wrt' arguments.
851
-
852
- protocol DifferentiableWhereClause : Differentiable {
853
- associatedtype Scalar
854
-
855
- @differentiable ( where Scalar: Differentiable) // expected-error {{'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement}}
856
- func test( value: Scalar ) -> Float
857
- }
858
-
859
- // Missing a `@differentiable` attribute.
879
+ // Test refining protocol requirements with `@differentiable` attribute.
860
880
861
881
public protocol Distribution {
862
882
associatedtype Value
@@ -870,18 +890,25 @@ public protocol DifferentiableDistribution: Differentiable, Distribution {
870
890
871
891
public protocol MissingDifferentiableDistribution : DifferentiableDistribution
872
892
where Value: Differentiable {
873
- func logProbability( of value: Value ) -> Float // expected-note {{candidate is missing attribute '@differentiable(wrt: self)'}}
893
+ // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: self)'}}
894
+ func logProbability( of value: Value ) -> Float
874
895
}
875
896
876
- // Missing `@differentiable` attribute, without printing the 'wrt' arguments .
897
+ // Test protocol requirement `@differentiable` attribute unsupported features .
877
898
878
- protocol Example : Differentiable {
879
- associatedtype Scalar : Differentiable
899
+ protocol ProtocolRequirementUnsupported : Differentiable {
900
+ associatedtype Scalar
880
901
881
- @ differentiable
882
- func test ( value : Scalar ) -> Float
883
- }
902
+ // expected-error @+1 {{'@ differentiable' attribute on protocol requirement cannot specify 'where' clause}}
903
+ @ differentiable ( where Scalar: Differentiable )
904
+ func unsupportedWhereClause ( value : Scalar ) -> Float
884
905
885
- protocol MissingDifferentiableTest : Example {
886
- func test( value: Scalar ) -> Float // expected-note {{candidate is missing attribute '@differentiable'}}
906
+ // expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'jvp:' or 'vjp:'}}
907
+ @differentiable ( wrt: x, jvp: dfoo, vjp: dfoo)
908
+ func unsupportedDerivatives( _ x: Float ) -> Float
909
+ }
910
+ extension ProtocolRequirementUnsupported {
911
+ func dfoo( _ x: Float ) -> ( Float , ( Float ) -> Float ) {
912
+ ( x, { $0 } )
913
+ }
887
914
}
0 commit comments