Skip to content

Commit 513d666

Browse files
committed
Merge remote-tracking branch 'github/tensorflow' into HEAD
2 parents d76188c + e531493 commit 513d666

17 files changed

+428
-107
lines changed

include/swift/AST/Attr.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,13 @@ class DifferentiableAttr final
17201720
/// attribute's where clause requirements. This is set only if the attribute
17211721
/// has a where clause.
17221722
GenericSignature DerivativeGenericSignature;
1723+
/// The source location of the implicitly inherited protocol requirement
1724+
/// `@differentiable` attribute. Used for diagnostics, not serialized.
1725+
///
1726+
/// This is set during conformance type-checking, only for implicit
1727+
/// `@differentiable` attributes created for non-public protocol witnesses of
1728+
/// protocol requirements with `@differentiable` attributes.
1729+
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
17231730

17241731
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
17251732
SourceRange baseRange, bool linear,
@@ -1805,6 +1812,14 @@ class DifferentiableAttr final
18051812
FuncDecl *getVJPFunction() const { return VJPFunction; }
18061813
void setVJPFunction(FuncDecl *decl);
18071814

1815+
SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const {
1816+
return ImplicitlyInheritedDifferentiableAttrLocation;
1817+
}
1818+
void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) {
1819+
assert(isImplicit());
1820+
ImplicitlyInheritedDifferentiableAttrLocation = loc;
1821+
}
1822+
18081823
/// Get the derivative generic environment for the given `@differentiable`
18091824
/// attribute and original function.
18101825
GenericEnvironment *

include/swift/AST/DiagnosticsSIL.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,9 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
540540
"cannot differentiate through multiple results", ())
541541
NOTE(autodiff_class_member_not_supported,none,
542542
"differentiating class members is not yet supported", ())
543+
NOTE(autodiff_implicitly_inherited_differentiable_attr_here,none,
544+
"differentiability required by the corresponding protocol requirement "
545+
"here", ())
543546
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`
544547
// functions.
545548
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,6 +2995,12 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
29952995
"overriding declaration is missing attribute '%0'", (StringRef))
29962996
NOTE(protocol_witness_missing_differentiable_attr,none,
29972997
"candidate is missing attribute '%0'", (StringRef))
2998+
NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none,
2999+
"non-public %1 %2 must have explicit '%0' attribute to satisfy "
3000+
"requirement %3 %4 (in protocol %6) because it is declared in a different "
3001+
"file than the conformance of %5 to %6",
3002+
(StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName,
3003+
Type, Type))
29983004

29993005
// @derivative
30003006
ERROR(derivative_attr_expected_result_tuple,none,

include/swift/SILOptimizer/Utils/Differentiation/ADContext.h

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -311,20 +311,41 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
311311
return diagnose(loc, diag, std::forward<U>(args)...);
312312
}
313313

314-
// For `SILDifferentiabilityWitness`es, try to find an AST function
315-
// declaration and `@differentiable` attribute. If they are found, emit an
316-
// error on the `@differentiable` attribute; otherwise, emit an error on the
317-
// SIL function. Emit a note at the non-differentiable operation.
314+
// For differentiability witnesses: try to find a `@differentiable` or
315+
// `@derivative` attribute. If an attribute is found, emit an error on it;
316+
// otherwise, emit an error on the original function.
318317
case DifferentiationInvoker::Kind::SILDifferentiabilityWitnessInvoker: {
319318
auto *witness = invoker.getSILDifferentiabilityWitnessInvoker();
320319
auto *original = witness->getOriginalFunction();
321-
if (auto *diffAttr = witness->getAttribute()) {
322-
diagnose(diffAttr->getLocation(),
320+
// If the witness has an associated attribute, emit an error at its
321+
// location.
322+
if (auto *attr = witness->getAttribute()) {
323+
diagnose(attr->getLocation(),
323324
diag::autodiff_function_not_differentiable_error)
324-
.highlight(diffAttr->getRangeWithAt());
325-
diagnose(original->getLocation().getSourceLoc(),
326-
diag::autodiff_when_differentiating_function_definition);
327-
} else {
325+
.highlight(attr->getRangeWithAt());
326+
// Emit informative note.
327+
bool emittedNote = false;
328+
// If the witness comes from an implicit `@differentiable` attribute
329+
// inherited from a protocol requirement's `@differentiable` attribute,
330+
// emit a note on the inherited attribute.
331+
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr)) {
332+
auto inheritedAttrLoc =
333+
diffAttr->getImplicitlyInheritedDifferentiableAttrLocation();
334+
if (inheritedAttrLoc.isValid()) {
335+
diagnose(inheritedAttrLoc,
336+
diag::autodiff_implicitly_inherited_differentiable_attr_here)
337+
.highlight(inheritedAttrLoc);
338+
emittedNote = true;
339+
}
340+
}
341+
// Otherwise, emit a note on the original function.
342+
if (!emittedNote) {
343+
diagnose(original->getLocation().getSourceLoc(),
344+
diag::autodiff_when_differentiating_function_definition);
345+
}
346+
}
347+
// Otherwise, emit an error on the original function.
348+
else {
328349
diagnose(original->getLocation().getSourceLoc(),
329350
diag::autodiff_function_not_differentiable_error);
330351
}

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
308308
/// witness.
309309
/// - If requirement's `@differentiable` attributes are met, or if `result` is
310310
/// not viable, returns `result`.
311-
/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`.
311+
/// - Otherwise, returns a "missing `@differentiable` attribute"
312+
/// `RequirementMatch`.
312313
// Note: the `result` argument is only necessary for using
313314
// `RequirementMatch::WitnessSubstitutions`.
314315
static RequirementMatch
@@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
384385
}
385386
if (!foundExactConfig) {
386387
bool success = false;
387-
if (supersetConfig) {
388-
// If the witness has a "superset" derivative configuration, create an
389-
// implicit `@differentiable` attribute with the exact requirement
390-
// `@differentiable` attribute parameter indices.
388+
// If no exact witness derivative configuration was found, check
389+
// conditions for creating an implicit witness `@differentiable` attribute
390+
// with the exact derivative configuration:
391+
// - If the witness has a "superset" derivative configuration.
392+
// - If the witness is less than public and is declared in the same file
393+
// as the conformance.
394+
// - `@differentiable` attributes are really only significant for public
395+
// declarations: it improves usability to not require explicit
396+
// `@differentiable` attributes for less-visible declarations.
397+
bool createImplicitWitnessAttribute =
398+
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
399+
// If the witness has less-than-public visibility and is declared in a
400+
// different file than the conformance, produce an error.
401+
if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public &&
402+
dc->getModuleScopeContext() !=
403+
witness->getDeclContext()->getModuleScopeContext()) {
404+
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
405+
// appear if associated type inference is involved.
406+
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
407+
return RequirementMatch(
408+
getStandinForAccessor(vdWitness, AccessorKind::Get),
409+
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
410+
} else {
411+
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
412+
reqDiffAttr);
413+
}
414+
}
415+
if (createImplicitWitnessAttribute) {
416+
auto derivativeGenSig = witnessAFD->getGenericSignature();
417+
if (supersetConfig)
418+
derivativeGenSig = supersetConfig->derivativeGenericSignature;
419+
// Use source location of the witness declaration as the source location
420+
// of the implicit `@differentiable` attribute.
391421
auto *newAttr = DifferentiableAttr::create(
392-
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
393-
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
394-
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
395-
/*vjp*/ None, supersetConfig->derivativeGenericSignature);
422+
witnessAFD, /*implicit*/ true, witness->getLoc(), witness->getLoc(),
423+
reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(),
424+
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
425+
// If the implicit attribute is inherited from a protocol requirement's
426+
// attribute, store the protocol requirement attribute's location for
427+
// use in diagnostics.
428+
if (witness->getFormalAccess() < AccessLevel::Public) {
429+
newAttr->getImplicitlyInheritedDifferentiableAttrLocation(
430+
reqDiffAttr->getLocation());
431+
}
396432
auto insertion = ctx.DifferentiableAttrs.try_emplace(
397433
{witnessAFD, newAttr->getParameterIndices()}, newAttr);
398434
// Valid `@differentiable` attributes are uniqued by original function
@@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
418454
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
419455
return RequirementMatch(
420456
getStandinForAccessor(vdWitness, AccessorKind::Get),
421-
MatchKind::DifferentiableConflict, reqDiffAttr);
457+
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
422458
} else {
423-
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
459+
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
424460
reqDiffAttr);
425461
}
426462
}
@@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
23182354
case MatchKind::NonObjC:
23192355
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
23202356
break;
2321-
case MatchKind::DifferentiableConflict: {
2357+
case MatchKind::MissingDifferentiableAttr: {
2358+
auto witness = match.Witness;
23222359
// Emit a note and fix-it showing the missing requirement `@differentiable`
23232360
// attribute.
23242361
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
23252362
assert(reqAttr);
23262363
// Omit printing `wrt:` clause if attribute's differentiability
23272364
// parameters match inferred differentiability parameters.
2328-
auto *original = cast<AbstractFunctionDecl>(match.Witness);
2365+
auto *original = cast<AbstractFunctionDecl>(witness);
23292366
auto *whereClauseGenEnv =
23302367
reqAttr->getDerivativeGenericEnvironment(original);
23312368
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters(
@@ -2336,11 +2373,29 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
23362373
llvm::raw_string_ostream os(reqDiffAttrString);
23372374
reqAttr->print(os, req, omitWrtClause, /*omitDerivativeFunctions*/ true);
23382375
os.flush();
2339-
diags
2340-
.diagnose(match.Witness,
2341-
diag::protocol_witness_missing_differentiable_attr,
2342-
reqDiffAttrString)
2343-
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2376+
// If the witness has less-than-public visibility and is declared in a
2377+
// different file than the conformance, emit a specialized diagnostic.
2378+
if (witness->getFormalAccess() < AccessLevel::Public &&
2379+
conformance->getDeclContext()->getModuleScopeContext() !=
2380+
witness->getDeclContext()->getModuleScopeContext()) {
2381+
diags
2382+
.diagnose(
2383+
witness,
2384+
diag::
2385+
protocol_witness_missing_differentiable_attr_nonpublic_other_file,
2386+
reqDiffAttrString, witness->getDescriptiveKind(),
2387+
witness->getFullName(), req->getDescriptiveKind(),
2388+
req->getFullName(), conformance->getType(),
2389+
conformance->getProtocol()->getDeclaredInterfaceType())
2390+
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2391+
}
2392+
// Otherwise, emit a general "missing attribute" diagnostic.
2393+
else {
2394+
diags
2395+
.diagnose(witness, diag::protocol_witness_missing_differentiable_attr,
2396+
reqDiffAttrString)
2397+
.fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' ');
2398+
}
23442399
break;
23452400
}
23462401
}

lib/Sema/TypeCheckProtocol.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,8 @@ enum class MatchKind : uint8_t {
210210
/// The witness is explicitly @nonobjc but the requirement is @objc.
211211
NonObjC,
212212

213-
/// The witness does not have a `@differentiable` attribute satisfying one
214-
/// from the requirement.
215-
DifferentiableConflict,
213+
/// The witness is missing a `@differentiable` attribute from the requirement.
214+
MissingDifferentiableAttr,
216215
};
217216

218217
/// Describes the kind of optional adjustment performed when
@@ -363,7 +362,7 @@ struct RequirementMatch {
363362
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),
364363
ReqEnv(None) {
365364
assert(!hasWitnessType() && "Should have witness type");
366-
assert(UnmetAttribute);
365+
assert(hasUnmetAttribute() && "Should have unmet attribute");
367366
}
368367

369368
RequirementMatch(ValueDecl *witness, MatchKind kind,
@@ -438,7 +437,7 @@ struct RequirementMatch {
438437
case MatchKind::RethrowsConflict:
439438
case MatchKind::ThrowsConflict:
440439
case MatchKind::NonObjC:
441-
case MatchKind::DifferentiableConflict:
440+
case MatchKind::MissingDifferentiableAttr:
442441
return false;
443442
}
444443

@@ -468,7 +467,7 @@ struct RequirementMatch {
468467
case MatchKind::RethrowsConflict:
469468
case MatchKind::ThrowsConflict:
470469
case MatchKind::NonObjC:
471-
case MatchKind::DifferentiableConflict:
470+
case MatchKind::MissingDifferentiableAttr:
472471
return false;
473472
}
474473

@@ -479,7 +478,9 @@ struct RequirementMatch {
479478
bool hasRequirement() { return Kind == MatchKind::MissingRequirement; }
480479

481480
/// Determine whether this requirement match has an unmet attribute.
482-
bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; }
481+
bool hasUnmetAttribute() {
482+
return Kind == MatchKind::MissingDifferentiableAttr;
483+
}
483484

484485
swift::Witness getWitness(ASTContext &ctx) const;
485486
};

0 commit comments

Comments
 (0)