Skip to content

Commit 07596cb

Browse files
authored
[AutoDiff upstream] Relax @differentiable for protocol witnesses. (#30629)
Previously, all witnesses of a `@differentiable` protocol requirement were required to have the same attribute (or one with superset parameter indices). However, this leads to many annotations on witnesses and is not ideal for usability. `@differentiable` attributes are really only significant on public witnesses, so that they are clearly `@differentiable` at a glance (in source code, interface files, and API documentation), without looking through protocol conformance hierarchies. Now, only *public* witnesses of `@differentiable` protocol requirements are required to have the same attribute (or one with superset parameter indices). For less-visible witnesses, an implicit `@differentiable` attribute is created with the same configuration as the requirement's. Resolves TF-1117. Upstreams #29771 from tensorflow branch.
1 parent 025cb9a commit 07596cb

File tree

5 files changed

+195
-71
lines changed

5 files changed

+195
-71
lines changed

include/swift/AST/Attr.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,6 +1709,13 @@ class DifferentiableAttr final
17091709
/// attribute's where clause requirements. This is set only if the attribute
17101710
/// has a where clause.
17111711
GenericSignature DerivativeGenericSignature;
1712+
/// The source location of the implicitly inherited protocol requirement
1713+
/// `@differentiable` attribute. Used for diagnostics, not serialized.
1714+
///
1715+
/// This is set during conformance type-checking, only for implicit
1716+
/// `@differentiable` attributes created for non-public protocol witnesses of
1717+
/// protocol requirements with `@differentiable` attributes.
1718+
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
17121719

17131720
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
17141721
SourceRange baseRange, bool linear,
@@ -1771,6 +1778,14 @@ class DifferentiableAttr final
17711778
DerivativeGenericSignature = derivativeGenSig;
17721779
}
17731780

1781+
SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const {
1782+
return ImplicitlyInheritedDifferentiableAttrLocation;
1783+
}
1784+
void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) {
1785+
assert(isImplicit());
1786+
ImplicitlyInheritedDifferentiableAttrLocation = loc;
1787+
}
1788+
17741789
/// Get the derivative generic environment for the given `@differentiable`
17751790
/// attribute and original function.
17761791
GenericEnvironment *

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2958,6 +2958,12 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
29582958
"overriding declaration is missing attribute '%0'", (StringRef))
29592959
NOTE(protocol_witness_missing_differentiable_attr,none,
29602960
"candidate is missing attribute '%0'", (StringRef))
2961+
NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none,
2962+
"non-public %1 %2 must have explicit '%0' attribute to satisfy "
2963+
"requirement %3 %4 (in protocol %6) because it is declared in a different "
2964+
"file than the conformance of %5 to %6",
2965+
(StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName,
2966+
Type, Type))
29612967

29622968
// @derivative
29632969
ERROR(derivative_attr_expected_result_tuple,none,

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
308308
/// witness.
309309
/// - If requirement's `@differentiable` attributes are met, or if `result` is
310310
/// not viable, returns `result`.
311-
/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`.
311+
/// - Otherwise, returns a "missing `@differentiable` attribute"
312+
/// `RequirementMatch`.
312313
// Note: the `result` argument is only necessary for using
313314
// `RequirementMatch::WitnessSubstitutions`.
314315
static RequirementMatch
@@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
384385
}
385386
if (!foundExactConfig) {
386387
bool success = false;
387-
if (supersetConfig) {
388-
// If the witness has a "superset" derivative configuration, create an
389-
// implicit `@differentiable` attribute with the exact requirement
390-
// `@differentiable` attribute parameter indices.
388+
// If no exact witness derivative configuration was found, check
389+
// conditions for creating an implicit witness `@differentiable` attribute
390+
// with the exact derivative configuration:
391+
// - If the witness has a "superset" derivative configuration.
392+
// - If the witness is less than public and is declared in the same file
393+
// as the conformance.
394+
// - `@differentiable` attributes are really only significant for public
395+
// declarations: it improves usability to not require explicit
396+
// `@differentiable` attributes for less-visible declarations.
397+
bool createImplicitWitnessAttribute =
398+
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
399+
// If the witness has less-than-public visibility and is declared in a
400+
// different file than the conformance, produce an error.
401+
if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public &&
402+
dc->getModuleScopeContext() !=
403+
witness->getDeclContext()->getModuleScopeContext()) {
404+
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
405+
// appear if associated type inference is involved.
406+
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
407+
return RequirementMatch(
408+
getStandinForAccessor(vdWitness, AccessorKind::Get),
409+
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
410+
} else {
411+
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
412+
reqDiffAttr);
413+
}
414+
}
415+
if (createImplicitWitnessAttribute) {
416+
auto derivativeGenSig = witnessAFD->getGenericSignature();
417+
if (supersetConfig)
418+
derivativeGenSig = supersetConfig->derivativeGenericSignature;
419+
// Use source location of the witness declaration as the source location
420+
// of the implicit `@differentiable` attribute.
391421
auto *newAttr = DifferentiableAttr::create(
392-
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
393-
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
394-
reqDiffAttr->getParameterIndices(),
395-
supersetConfig->derivativeGenericSignature);
422+
witnessAFD, /*implicit*/ true, witness->getLoc(), witness->getLoc(),
423+
reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(),
424+
derivativeGenSig);
425+
// If the implicit attribute is inherited from a protocol requirement's
426+
// attribute, store the protocol requirement attribute's location for
427+
// use in diagnostics.
428+
if (witness->getFormalAccess() < AccessLevel::Public) {
429+
newAttr->getImplicitlyInheritedDifferentiableAttrLocation(
430+
reqDiffAttr->getLocation());
431+
}
396432
auto insertion = ctx.DifferentiableAttrs.try_emplace(
397433
{witnessAFD, newAttr->getParameterIndices()}, newAttr);
398434
// Valid `@differentiable` attributes are uniqued by original function
@@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
418454
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
419455
return RequirementMatch(
420456
getStandinForAccessor(vdWitness, AccessorKind::Get),
421-
MatchKind::DifferentiableConflict, reqDiffAttr);
457+
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
422458
} else {
423-
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
459+
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
424460
reqDiffAttr);
425461
}
426462
}
@@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
23182354
case MatchKind::NonObjC:
23192355
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
23202356
break;
2321-
case MatchKind::DifferentiableConflict: {
2357+
case MatchKind::MissingDifferentiableAttr: {
2358+
auto *witness = match.Witness;
23222359
// Emit a note and fix-it showing the missing requirement `@differentiable`
23232360
// attribute.
23242361
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
23252362
assert(reqAttr);
23262363
// Omit printing `wrt:` clause if attribute's differentiability
23272364
// parameters match inferred differentiability parameters.
2328-
auto *original = cast<AbstractFunctionDecl>(match.Witness);
2365+
auto *original = cast<AbstractFunctionDecl>(witness);
23292366
auto *whereClauseGenEnv =
23302367
reqAttr->getDerivativeGenericEnvironment(original);
23312368
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters(
@@ -2336,11 +2373,29 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
23362373
llvm::raw_string_ostream os(reqDiffAttrString);
23372374
reqAttr->print(os, req, omitWrtClause);
23382375
os.flush();
2339-
diags
2340-
.diagnose(match.Witness,
2341-
diag::protocol_witness_missing_differentiable_attr,
2342-
reqDiffAttrString)
2343-
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2376+
// If the witness has less-than-public visibility and is declared in a
2377+
// different file than the conformance, emit a specialized diagnostic.
2378+
if (witness->getFormalAccess() < AccessLevel::Public &&
2379+
conformance->getDeclContext()->getModuleScopeContext() !=
2380+
witness->getDeclContext()->getModuleScopeContext()) {
2381+
diags
2382+
.diagnose(
2383+
witness,
2384+
diag::
2385+
protocol_witness_missing_differentiable_attr_nonpublic_other_file,
2386+
reqDiffAttrString, witness->getDescriptiveKind(),
2387+
witness->getFullName(), req->getDescriptiveKind(),
2388+
req->getFullName(), conformance->getType(),
2389+
conformance->getProtocol()->getDeclaredInterfaceType())
2390+
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2391+
}
2392+
// Otherwise, emit a general "missing attribute" diagnostic.
2393+
else {
2394+
diags
2395+
.diagnose(witness, diag::protocol_witness_missing_differentiable_attr,
2396+
reqDiffAttrString)
2397+
.fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' ');
2398+
}
23442399
break;
23452400
}
23462401
}

lib/Sema/TypeCheckProtocol.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,8 @@ enum class MatchKind : uint8_t {
209209
/// The witness is explicitly @nonobjc but the requirement is @objc.
210210
NonObjC,
211211

212-
/// The witness does not have a `@differentiable` attribute satisfying one
213-
/// from the requirement.
214-
DifferentiableConflict,
212+
/// The witness is missing a `@differentiable` attribute from the requirement.
213+
MissingDifferentiableAttr,
215214
};
216215

217216
/// Describes the kind of optional adjustment performed when
@@ -362,7 +361,7 @@ struct RequirementMatch {
362361
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),
363362
ReqEnv(None) {
364363
assert(!hasWitnessType() && "Should have witness type");
365-
assert(UnmetAttribute);
364+
assert(hasUnmetAttribute() && "Should have unmet attribute");
366365
}
367366

368367
RequirementMatch(ValueDecl *witness, MatchKind kind,
@@ -437,7 +436,7 @@ struct RequirementMatch {
437436
case MatchKind::RethrowsConflict:
438437
case MatchKind::ThrowsConflict:
439438
case MatchKind::NonObjC:
440-
case MatchKind::DifferentiableConflict:
439+
case MatchKind::MissingDifferentiableAttr:
441440
return false;
442441
}
443442

@@ -467,7 +466,7 @@ struct RequirementMatch {
467466
case MatchKind::RethrowsConflict:
468467
case MatchKind::ThrowsConflict:
469468
case MatchKind::NonObjC:
470-
case MatchKind::DifferentiableConflict:
469+
case MatchKind::MissingDifferentiableAttr:
471470
return false;
472471
}
473472

@@ -478,7 +477,9 @@ struct RequirementMatch {
478477
bool hasRequirement() { return Kind == MatchKind::MissingRequirement; }
479478

480479
/// Determine whether this requirement match has an unmet attribute.
481-
bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; }
480+
bool hasUnmetAttribute() {
481+
return Kind == MatchKind::MissingDifferentiableAttr;
482+
}
482483

483484
swift::Witness getWitness(ASTContext &ctx) const;
484485
};

0 commit comments

Comments
 (0)