Skip to content

Commit 23756a8

Browse files
authored
---
yaml --- r: 340987 b: refs/heads/rxwei-patch-1 c: 714810c h: refs/heads/master i: 340985: a50cf9b 340983: e2365fb
1 parent cce25ad commit 23756a8

File tree

4 files changed

+76
-31
lines changed

4 files changed

+76
-31
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: fe1b8482eb0448684e3e7f1a2f915cc0ca5386b9
1018+
refs/heads/rxwei-patch-1: 714810c438d9fa02a5389d4863cbe3ccbc0fe3de
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2784,8 +2784,12 @@ ERROR(differentiable_attr_invalid_access,none,
27842784
ERROR(differentiable_attr_result_not_differentiable,none,
27852785
"can only differentiate functions with results that conform to "
27862786
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2787-
ERROR(differentiable_attr_protocol_where_clause,none,
2788-
"'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement", ())
2787+
ERROR(differentiable_attr_protocol_req_where_clause,none,
2788+
"'@differentiable' attribute on protocol requirement cannot specify "
2789+
"'where' clause", ())
2790+
ERROR(differentiable_attr_protocol_req_assoc_func,none,
2791+
"'@differentiable' attribute on protocol requirement cannot specify "
2792+
"'jvp:' or 'vjp:'", ())
27892793
ERROR(differentiable_attr_empty_where_clause,none,
27902794
"empty 'where' clause in '@differentiable' attribute", ())
27912795
ERROR(differentiable_attr_nongeneric_trailing_where,none,

branches/rxwei-patch-1/lib/Sema/TypeCheckAttr.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3267,6 +3267,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
32673267
// Start type-checking the arguments of the @differentiable attribute. This
32683268
// covers 'wrt:', 'jvp:', 'vjp:', and 'where', all of which are optional.
32693269

3270+
// `@differentiable` attributes on protocol requirements do not support
3271+
// JVP/VJP or 'where' clauses.
3272+
bool isOriginalProtocolRequirement =
3273+
isa<ProtocolDecl>(original->getDeclContext()) &&
3274+
original->isProtocolRequirement();
3275+
32703276
// Handle 'where' clause, if it exists.
32713277
// - Resolve attribute where clause requirements and store in the attribute
32723278
// for serialization.
@@ -3276,12 +3282,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
32763282
GenericSignature *whereClauseGenSig = nullptr;
32773283
GenericEnvironment *whereClauseGenEnv = nullptr;
32783284
if (auto whereClause = attr->getWhereClause()) {
3279-
// 'where' clauses in '@differentiable' attributes of protocol
3280-
// requirements are not supported.
3281-
if (isa<ProtocolDecl>(original->getDeclContext()) &&
3282-
original->isProtocolRequirement()) {
3285+
// `@differentiable` attributes on protocol requirements do not support
3286+
// 'where' clauses.
3287+
if (isOriginalProtocolRequirement) {
32833288
TC.diagnose(attr->getLocation(),
3284-
diag::differentiable_attr_protocol_where_clause);
3289+
diag::differentiable_attr_protocol_req_where_clause);
32853290
attr->setInvalid();
32863291
return;
32873292
}
@@ -3398,6 +3403,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
33983403
return;
33993404
}
34003405

3406+
// `@differentiable` attributes on protocol requirements do not support
3407+
// JVP/VJP.
3408+
if (isOriginalProtocolRequirement && (attr->getJVP() || attr->getVJP())) {
3409+
TC.diagnose(attr->getLocation(),
3410+
diag::differentiable_attr_protocol_req_assoc_func);
3411+
attr->setInvalid();
3412+
return;
3413+
}
3414+
34013415
// Checks that the `candidate` function type equals the `required` function
34023416
// type, disregarding parameter labels and tuple result labels.
34033417
std::function<bool(CanAnyFunctionType, CanType)> checkFunctionSignature;

branches/rxwei-patch-1/test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
612612
return x
613613
}
614614

615-
protocol DifferentiableAttrRequirements : Differentiable {
615+
protocol ProtocolRequirements : Differentiable {
616616
// expected-note @+2 {{protocol requires initializer 'init(x:y:)' with type '(x: Float, y: Float)'}}
617617
@differentiable
618618
init(x: Float, y: Float)
@@ -638,8 +638,13 @@ protocol DifferentiableAttrRequirements : Differentiable {
638638
func f2(_ x: Float, _ y: Float) -> Float
639639
}
640640

641-
// expected-error @+1 {{does not conform to protocol 'DifferentiableAttrRequirements'}}
642-
struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
641+
protocol ProtocolRequirementsRefined : ProtocolRequirements {
642+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
643+
func f1(_ x: Float) -> Float
644+
}
645+
646+
// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}}
647+
struct DiffAttrConformanceErrors : ProtocolRequirements {
643648
var x: Float
644649
var y: Float
645650

@@ -683,6 +688,31 @@ struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
683688
}
684689
}
685690

691+
protocol ProtocolRequirementsWithDefault_NoConformingTypes {
692+
@differentiable
693+
func f1(_ x: Float) -> Float
694+
}
695+
extension ProtocolRequirementsWithDefault_NoConformingTypes {
696+
// TODO(TF-650): It would be nice to diagnose protocol default implementation
697+
// with missing `@differentiable` attribute.
698+
func f1(_ x: Float) -> Float { x }
699+
}
700+
701+
protocol ProtocolRequirementsWithDefault {
702+
// expected-note @+2 {{protocol requires function 'f1'}}
703+
@differentiable
704+
func f1(_ x: Float) -> Float
705+
}
706+
extension ProtocolRequirementsWithDefault {
707+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
708+
func f1(_ x: Float) -> Float { x }
709+
}
710+
// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}}
711+
struct DiffAttrConformanceErrors2 : ProtocolRequirementsWithDefault {
712+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
713+
func f1(_ x: Float) -> Float { x }
714+
}
715+
686716
protocol NotRefiningDiffable {
687717
@differentiable(wrt: x)
688718
// expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}}
@@ -846,17 +876,7 @@ func inout2(x: Float, y: inout Float) -> Float {
846876
let _ = x + y
847877
}
848878

849-
850-
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
851-
852-
protocol DifferentiableWhereClause: Differentiable {
853-
associatedtype Scalar
854-
855-
@differentiable(where Scalar: Differentiable) // expected-error {{'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement}}
856-
func test(value: Scalar) -> Float
857-
}
858-
859-
// Missing a `@differentiable` attribute.
879+
// Test refining protocol requirements with `@differentiable` attribute.
860880

861881
public protocol Distribution {
862882
associatedtype Value
@@ -870,18 +890,25 @@ public protocol DifferentiableDistribution: Differentiable, Distribution {
870890

871891
public protocol MissingDifferentiableDistribution: DifferentiableDistribution
872892
where Value: Differentiable {
873-
func logProbability(of value: Value) -> Float // expected-note {{candidate is missing attribute '@differentiable(wrt: self)'}}
893+
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: self)'}}
894+
func logProbability(of value: Value) -> Float
874895
}
875896

876-
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
897+
// Test protocol requirement `@differentiable` attribute unsupported features.
877898

878-
protocol Example: Differentiable {
879-
associatedtype Scalar: Differentiable
899+
protocol ProtocolRequirementUnsupported : Differentiable {
900+
associatedtype Scalar
880901

881-
@differentiable
882-
func test(value: Scalar) -> Float
883-
}
902+
// expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'where' clause}}
903+
@differentiable(where Scalar: Differentiable)
904+
func unsupportedWhereClause(value: Scalar) -> Float
884905

885-
protocol MissingDifferentiableTest: Example {
886-
func test(value: Scalar) -> Float // expected-note {{candidate is missing attribute '@differentiable'}}
906+
// expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'jvp:' or 'vjp:'}}
907+
@differentiable(wrt: x, jvp: dfoo, vjp: dfoo)
908+
func unsupportedDerivatives(_ x: Float) -> Float
909+
}
910+
extension ProtocolRequirementUnsupported {
911+
func dfoo(_ x: Float) -> (Float, (Float) -> Float) {
912+
(x, { $0 })
913+
}
887914
}

0 commit comments

Comments
 (0)