Skip to content

Commit e2365fb

Browse files
eaplataniosrxwei
authored andcommitted
---
yaml --- r: 340983 b: refs/heads/rxwei-patch-1 c: 43e6ad6 h: refs/heads/master i: 340981: 22a6926 340979: b881a59 340975: 43e8f65
1 parent 6364099 commit e2365fb

File tree

9 files changed

+253
-40
lines changed

9 files changed

+253
-40
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 6b3d8742e82ca0b76100ad6a802ffcb8e2bf11ee
1018+
refs/heads/rxwei-patch-1: 43e6ad6b4f3681a2f8357d4bab01864716bd4513
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2784,6 +2784,8 @@ ERROR(differentiable_attr_invalid_access,none,
27842784
ERROR(differentiable_attr_result_not_differentiable,none,
27852785
"can only differentiate functions with results that conform to "
27862786
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2787+
ERROR(differentiable_attr_protocol_where_clause,none,
2788+
"'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement", ())
27872789
ERROR(differentiable_attr_empty_where_clause,none,
27882790
"empty 'where' clause in '@differentiable' attribute", ())
27892791
ERROR(differentiable_attr_nongeneric_trailing_where,none,

branches/rxwei-patch-1/lib/AST/Attr.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,16 @@ static void printDifferentiableAttrArguments(
499499
stream << "vjp: " << vjp->Name;
500500
}
501501
// Print 'where' clause, if any.
502-
if (!attr->getRequirements().empty()) {
502+
// First, filter out requirements satisfied by the original function's
503+
// generic signature. They should not be printed.
504+
auto requirementsToPrint =
505+
makeFilterRange(attr->getRequirements(), [&](Requirement req) {
506+
if (auto *originalGenSig = original->getGenericSignature())
507+
if (originalGenSig->isRequirementSatisfied(req))
508+
return false;
509+
return true;
510+
});
511+
if (!requirementsToPrint.empty()) {
503512
if (!isLeadingClause)
504513
stream << ' ';
505514
stream << "where ";
@@ -515,15 +524,6 @@ static void printDifferentiableAttrArguments(
515524
return genericEnv->getSugaredType(Ty);
516525
};
517526
}
518-
// Filter out requirements satisfied by original function's generic
519-
// signature. They should not be printed.
520-
auto requirementsToPrint =
521-
makeFilterRange(attr->getRequirements(), [&](Requirement req) {
522-
if (auto *originalGenSig = original->getGenericSignature())
523-
if (originalGenSig->isRequirementSatisfied(req))
524-
return false;
525-
return true;
526-
});
527527
interleave(requirementsToPrint, [&](Requirement req) {
528528
if (auto *originalGenSig = original->getGenericSignature())
529529
if (originalGenSig->isRequirementSatisfied(req))

branches/rxwei-patch-1/lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3276,6 +3276,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
32763276
GenericSignature *whereClauseGenSig = nullptr;
32773277
GenericEnvironment *whereClauseGenEnv = nullptr;
32783278
if (auto whereClause = attr->getWhereClause()) {
3279+
// 'where' clauses in '@differentiable' attributes of protocol
3280+
// requirements are not supported.
3281+
if (isa<ProtocolDecl>(original->getDeclContext()) &&
3282+
original->isProtocolRequirement()) {
3283+
TC.diagnose(attr->getLocation(),
3284+
diag::differentiable_attr_protocol_where_clause);
3285+
attr->setInvalid();
3286+
return;
3287+
}
32793288
if (whereClause->getRequirements().empty()) {
32803289
// Where clause must not be empty.
32813290
TC.diagnose(attr->getLocation(),

branches/rxwei-patch-1/lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,92 @@ static bool parameterTypesMatch(const ValueDecl *derivedDecl,
591591
return true;
592592
}
593593

594+
// SWIFT_ENABLE_TENSORFLOW
595+
static bool overridesDifferentiableAttribute(ValueDecl *derivedDecl,
596+
ValueDecl *baseDecl) {
597+
ASTContext &ctx = derivedDecl->getASTContext();
598+
auto &diags = ctx.Diags;
599+
600+
auto *derivedAFD = dyn_cast<AbstractFunctionDecl>(derivedDecl);
601+
auto *baseAFD = dyn_cast<AbstractFunctionDecl>(baseDecl);
602+
603+
if (!derivedAFD || !baseAFD)
604+
return false;
605+
606+
auto derivedDAs = derivedAFD->getAttrs().getAttributes<DifferentiableAttr>();
607+
auto baseDAs = baseAFD->getAttrs().getAttributes<DifferentiableAttr>();
608+
609+
// Make sure all the differentiable attributes in `baseDecl` are
610+
// also declared in `derivedDecl`.
611+
for (auto baseDA : baseDAs) {
612+
auto baseParameters = baseDA->getParameterIndices();
613+
auto defined = false;
614+
for (auto derivedDA : derivedDAs) {
615+
auto derivedParameters = derivedDA->getParameterIndices();
616+
if (derivedParameters &&
617+
baseParameters &&
618+
AutoDiffIndexSubset::get(
619+
ctx, baseParameters->parameters)
620+
->isSubsetOf(AutoDiffIndexSubset::get(
621+
ctx, derivedParameters->parameters))) {
622+
defined = true;
623+
break;
624+
}
625+
}
626+
if (!defined) {
627+
// Omit printing wrt clause if attribute differentiation parameters match
628+
// inferred differentiation parameters.
629+
auto *inferredParameters = TypeChecker::inferDifferentiableParameters(
630+
derivedAFD, nullptr);
631+
bool omitWrtClause = !baseParameters ||
632+
baseParameters->parameters.count() ==
633+
inferredParameters->parameters.count();
634+
// Get `@differentiable` attribute description.
635+
std::string baseDAString;
636+
llvm::raw_string_ostream stream(baseDAString);
637+
baseDA->print(stream, derivedDecl, omitWrtClause);
638+
diags.diagnose(
639+
derivedDecl,
640+
diag::protocol_witness_missing_differentiable_attr,
641+
StringRef(stream.str()).trim());
642+
return false;
643+
}
644+
}
645+
646+
// If there is no differentiable attribute in `derivedDecl`, then
647+
// overriding is not allowed.
648+
if (derivedDAs.empty())
649+
return false;
650+
651+
// Finally, go through all differentiable attributes in
652+
// `derivedDecl` and check if they subsume any of the
653+
// differentiable attributes in `baseDecl`.
654+
for (auto derivedDA : derivedDAs) {
655+
auto derivedParameters = derivedDA->getParameterIndices();
656+
auto overrides = true;
657+
for (auto baseDA : baseDAs) {
658+
auto baseParameters = baseDA->getParameterIndices();
659+
// If the differentiable indices of `derivedDA` are a
660+
// subset of those of `baseDA`, then `baseDA` subsumes
661+
// `derivedDA` and the function is marked as overridden.
662+
if (derivedParameters &&
663+
baseParameters &&
664+
AutoDiffIndexSubset::get(
665+
ctx, derivedParameters->parameters)
666+
->isSubsetOf(AutoDiffIndexSubset::get(
667+
ctx, baseParameters->parameters))) {
668+
overrides = false;
669+
break;
670+
}
671+
}
672+
if (overrides)
673+
return true;
674+
}
675+
676+
return false;
677+
}
678+
// SWIFT_ENABLE_TENSORFLOW END
679+
594680
/// Returns true if the given declaration is for the `NSObject.hashValue`
595681
/// property.
596682
static bool isNSObjectHashValue(ValueDecl *baseDecl) {
@@ -746,6 +832,12 @@ SmallVector<OverrideMatch, 2> OverrideMatcher::match(
746832
if (!areOverrideCompatibleSimple(decl, parentDecl))
747833
continue;
748834

835+
// SWIFT_ENABLE_TENSORFLOW
836+
// Check whether the differentiable attribute allows overriding.
837+
if (overridesDifferentiableAttribute(decl, parentDecl))
838+
continue;
839+
// SWIFT_ENABLE_TENSORFLOW END
840+
749841
auto parentMethod = dyn_cast<AbstractFunctionDecl>(parentDecl);
750842
auto parentStorage = dyn_cast<AbstractStorageDecl>(parentDecl);
751843
assert(parentMethod || parentStorage);

branches/rxwei-patch-1/lib/Sema/TypeCheckProtocol.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,11 @@ swift::matchWitness(
569569
// SWIFT_ENABLE_TENSORFLOW
570570
auto result = finalize(anyRenaming, optionalAdjustments);
571571
if (result.isViable()) {
572-
// '@differentiable' attributes must match completely.
572+
// '@differentiable' attributes must match completely. If there exists a
573+
// '@differentiable' attribute with a superset of the "wrt" parameters of
574+
// a requirement, then an '@differentiable' attribute is added
575+
// automatically.
576+
ASTContext &ctx = witness->getASTContext();
573577
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
574578
auto witnessDiffAttrs = witnessAttrs
575579
.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
@@ -579,14 +583,44 @@ swift::matchWitness(
579583
reqDiffAttr->getParameterIndices() &&
580584
witnessDiffAttr->parametersMatch(*reqDiffAttr);
581585
});
586+
bool reqDiffAttrSupersetMatch = llvm::any_of(
587+
witnessDiffAttrs, [&](const DifferentiableAttr *witnessDiffAttr) {
588+
return witnessDiffAttr->getParameterIndices() &&
589+
reqDiffAttr->getParameterIndices() &&
590+
AutoDiffIndexSubset::get(
591+
ctx, witnessDiffAttr->getParameterIndices()->parameters)
592+
->isSupersetOf(AutoDiffIndexSubset::get(
593+
ctx, reqDiffAttr->getParameterIndices()->parameters));
594+
});
582595
if (!reqDiffAttrMatch) {
583-
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
584-
return RequirementMatch(
585-
getStandinForAccessor(vdWitness, AccessorKind::Get),
586-
MatchKind::DifferentiableConflict, reqDiffAttr);
587-
else
588-
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
589-
reqDiffAttr);
596+
auto implicitDiffAttr = false;
597+
if (reqDiffAttrSupersetMatch) {
598+
auto *newAttr = DifferentiableAttr::create(
599+
ctx, /*implicit*/ true, reqDiffAttr->AtLoc,
600+
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
601+
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
602+
/*vjp*/ None, reqDiffAttr->getRequirements());
603+
auto insertion = ctx.DifferentiableAttrs.try_emplace(
604+
{witness, newAttr->getParameterIndices()}, newAttr);
605+
// Valid `@differentiable` attributes are uniqued by their parameter
606+
// indices. Reject duplicate attributes for the same decl and parameter
607+
// indices pair.
608+
if (!insertion.second) {
609+
newAttr->setInvalid();
610+
} else {
611+
witness->getAttrs().add(newAttr);
612+
implicitDiffAttr = true;
613+
}
614+
}
615+
if (!implicitDiffAttr) {
616+
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
617+
return RequirementMatch(
618+
getStandinForAccessor(vdWitness, AccessorKind::Get),
619+
MatchKind::DifferentiableConflict, reqDiffAttr);
620+
else
621+
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
622+
reqDiffAttr);
623+
}
590624
}
591625
}
592626
}

branches/rxwei-patch-1/stdlib/public/core/AutoDiff.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public extension VectorProtocol {
100100
}
101101
}
102102

103-
/* Note: These default-implemented opreators will slow down type-checking
103+
/* Note: These default-implemented operators will slow down type-checking
104104
performance and break existing code.
105105

106106
public extension VectorProtocol {

branches/rxwei-patch-1/test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -588,17 +588,6 @@ extension FloatingPoint {
588588
}
589589
}
590590

591-
protocol MethodDiffReq {
592-
@differentiable(wrt: self, vjp: vjpFoo where Self : Differentiable)
593-
func foo() -> Self
594-
}
595-
596-
extension MethodDiffReq where Self : Differentiable {
597-
func vjpFoo(x: Self) -> (Self, (Self.TangentVector) -> Self.TangentVector) {
598-
return (self, { $0 })
599-
}
600-
}
601-
602591
// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}}
603592
@differentiable(wrt: x, vjp: vjpNonvariadic)
604593
func variadic(_ x: Float, indices: Int32...) -> Float {
@@ -647,10 +636,6 @@ protocol DifferentiableAttrRequirements : Differentiable {
647636
// expected-note @+2 {{protocol requires function 'f2'}}
648637
@differentiable(wrt: (self, x, y))
649638
func f2(_ x: Float, _ y: Float) -> Float
650-
651-
// expected-note @+2 {{protocol requires function 'generic'}}
652-
@differentiable(where T : Differentiable)
653-
func generic<T>(_ x: T) -> T
654639
}
655640

656641
// expected-error @+1 {{does not conform to protocol 'DifferentiableAttrRequirements'}}
@@ -696,11 +681,6 @@ struct DiffAttrConformanceErrors : DifferentiableAttrRequirements {
696681
func f2(_ x: Float, _ y: Float) -> Float {
697682
return x + y
698683
}
699-
700-
// expected-note @+1 {{candidate is missing attribute '@differentiable(where T : Differentiable)'}}
701-
func generic<T>(_ x: T) -> T {
702-
return x
703-
}
704684
}
705685

706686
protocol NotRefiningDiffable {
@@ -865,3 +845,43 @@ func inout1(x: Float, y: inout Float) -> Void {
865845
func inout2(x: Float, y: inout Float) -> Float {
866846
let _ = x + y
867847
}
848+
849+
850+
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
851+
852+
protocol DifferentiableWhereClause: Differentiable {
853+
associatedtype Scalar
854+
855+
@differentiable(where Scalar: Differentiable) // expected-error {{'where' clauses cannot be used in a '@differentiable' attribute on a protocol requirement}}
856+
func test(value: Scalar) -> Float
857+
}
858+
859+
// Missing a `@differentiable` attribute.
860+
861+
public protocol Distribution {
862+
associatedtype Value
863+
func logProbability(of value: Value) -> Float
864+
}
865+
866+
public protocol DifferentiableDistribution: Differentiable, Distribution {
867+
@differentiable(wrt: self)
868+
func logProbability(of value: Value) -> Float
869+
}
870+
871+
public protocol MissingDifferentiableDistribution: DifferentiableDistribution
872+
where Value: Differentiable {
873+
func logProbability(of value: Value) -> Float // expected-note {{candidate is missing attribute '@differentiable(wrt: self)'}}
874+
}
875+
876+
// Missing `@differentiable` attribute, without printing the 'wrt' arguments.
877+
878+
protocol Example: Differentiable {
879+
associatedtype Scalar: Differentiable
880+
881+
@differentiable
882+
func test(value: Scalar) -> Float
883+
}
884+
885+
protocol MissingDifferentiableTest: Example {
886+
func test(value: Scalar) -> Float // expected-note {{candidate is missing attribute '@differentiable'}}
887+
}

branches/rxwei-patch-1/test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,60 @@ struct S : P {
137137
}
138138
}
139139

140+
// MARK: - Overridden protocol method adding differentiable attribute.
141+
142+
public protocol Distribution {
143+
associatedtype Value
144+
func logProbability(of value: Value) -> Float
145+
}
146+
147+
public protocol DifferentiableDistribution: Differentiable, Distribution {
148+
@differentiable(wrt: self)
149+
func logProbability(of value: Value) -> Float
150+
}
151+
152+
struct Foo: DifferentiableDistribution {
153+
@differentiable(wrt: self)
154+
func logProbability(of value: Float) -> Float {
155+
.zero
156+
}
157+
}
158+
159+
@differentiable
160+
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
161+
x.logProbability(of: .zero)
162+
}
163+
164+
// Adding a more general `@differentiable` attribute.
165+
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
166+
where Value: Differentiable {
167+
@differentiable(wrt: self)
168+
@differentiable(wrt: (self, value))
169+
func logProbability(of value: Value) -> Float
170+
}
171+
172+
@differentiable
173+
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Float
174+
where T.Value: AdditiveArithmetic {
175+
x.logProbability(of: value)
176+
}
177+
178+
protocol DifferentiableFoo {
179+
associatedtype T: Differentiable
180+
@differentiable(wrt: x)
181+
func foo(_ x: T) -> Float
182+
}
183+
184+
protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
185+
@differentiable(wrt: (self, x))
186+
func foo(_ x: T) -> Float
187+
}
188+
189+
struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
190+
@differentiable(wrt: (self, x))
191+
func foo(_ x: Float) -> Float {
192+
x
193+
}
194+
}
195+
140196
runAllTests()

0 commit comments

Comments
 (0)