@@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
308
308
// / witness.
309
309
// / - If requirement's `@differentiable` attributes are met, or if `result` is
310
310
// / not viable, returns `result`.
311
- // / - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`.
311
+ // / - Otherwise, returns a "missing `@differentiable` attribute"
312
+ // / `RequirementMatch`.
312
313
// Note: the `result` argument is only necessary for using
313
314
// `RequirementMatch::WitnessSubstitutions`.
314
315
static RequirementMatch
@@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
384
385
}
385
386
if (!foundExactConfig) {
386
387
bool success = false ;
387
- if (supersetConfig) {
388
- // If the witness has a "superset" derivative configuration, create an
389
- // implicit `@differentiable` attribute with the exact requirement
390
- // `@differentiable` attribute parameter indices.
388
+ // If no exact witness derivative configuration was found, check
389
+ // conditions for creating an implicit witness `@differentiable` attribute
390
+ // with the exact derivative configuration:
391
+ // - If the witness has a "superset" derivative configuration.
392
+ // - If the witness is less than public and is declared in the same file
393
+ // as the conformance.
394
+ // - `@differentiable` attributes are really only significant for public
395
+ // declarations: it improves usability to not require explicit
396
+ // `@differentiable` attributes for less-visible declarations.
397
+ bool createImplicitWitnessAttribute =
398
+ supersetConfig || witness->getFormalAccess () < AccessLevel::Public;
399
+ // If the witness has less-than-public visibility and is declared in a
400
+ // different file than the conformance, produce an error.
401
+ if (!supersetConfig && witness->getFormalAccess () < AccessLevel::Public &&
402
+ dc->getModuleScopeContext () !=
403
+ witness->getDeclContext ()->getModuleScopeContext ()) {
404
+ // FIXME(TF-1014): `@differentiable` attribute diagnostic does not
405
+ // appear if associated type inference is involved.
406
+ if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
407
+ return RequirementMatch (
408
+ getStandinForAccessor (vdWitness, AccessorKind::Get),
409
+ MatchKind::MissingDifferentiableAttr, reqDiffAttr);
410
+ } else {
411
+ return RequirementMatch (witness, MatchKind::MissingDifferentiableAttr,
412
+ reqDiffAttr);
413
+ }
414
+ }
415
+ if (createImplicitWitnessAttribute) {
416
+ auto derivativeGenSig = witnessAFD->getGenericSignature ();
417
+ if (supersetConfig)
418
+ derivativeGenSig = supersetConfig->derivativeGenericSignature ;
419
+ // Use source location of the witness declaration as the source location
420
+ // of the implicit `@differentiable` attribute.
391
421
auto *newAttr = DifferentiableAttr::create (
392
- witnessAFD, /* implicit*/ true , reqDiffAttr->AtLoc ,
393
- reqDiffAttr->getRange (), reqDiffAttr->isLinear (),
394
- reqDiffAttr->getParameterIndices (), /* jvp*/ None,
395
- /* vjp*/ None, supersetConfig->derivativeGenericSignature );
422
+ witnessAFD, /* implicit*/ true , witness->getLoc (), witness->getLoc (),
423
+ reqDiffAttr->isLinear (), reqDiffAttr->getParameterIndices (),
424
+ /* jvp*/ None, /* vjp*/ None, derivativeGenSig);
425
+ // If the implicit attribute is inherited from a protocol requirement's
426
+ // attribute, store the protocol requirement attribute's location for
427
+ // use in diagnostics.
428
+ if (witness->getFormalAccess () < AccessLevel::Public) {
429
+ newAttr->getImplicitlyInheritedDifferentiableAttrLocation (
430
+ reqDiffAttr->getLocation ());
431
+ }
396
432
auto insertion = ctx.DifferentiableAttrs .try_emplace (
397
433
{witnessAFD, newAttr->getParameterIndices ()}, newAttr);
398
434
// Valid `@differentiable` attributes are uniqued by original function
@@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
418
454
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
419
455
return RequirementMatch (
420
456
getStandinForAccessor (vdWitness, AccessorKind::Get),
421
- MatchKind::DifferentiableConflict , reqDiffAttr);
457
+ MatchKind::MissingDifferentiableAttr , reqDiffAttr);
422
458
} else {
423
- return RequirementMatch (witness, MatchKind::DifferentiableConflict ,
459
+ return RequirementMatch (witness, MatchKind::MissingDifferentiableAttr ,
424
460
reqDiffAttr);
425
461
}
426
462
}
@@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
2318
2354
case MatchKind::NonObjC:
2319
2355
diags.diagnose (match.Witness , diag::protocol_witness_not_objc);
2320
2356
break ;
2321
- case MatchKind::DifferentiableConflict: {
2357
+ case MatchKind::MissingDifferentiableAttr: {
2358
+ auto witness = match.Witness ;
2322
2359
// Emit a note and fix-it showing the missing requirement `@differentiable`
2323
2360
// attribute.
2324
2361
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute );
2325
2362
assert (reqAttr);
2326
2363
// Omit printing `wrt:` clause if attribute's differentiability
2327
2364
// parameters match inferred differentiability parameters.
2328
- auto *original = cast<AbstractFunctionDecl>(match. Witness );
2365
+ auto *original = cast<AbstractFunctionDecl>(witness );
2329
2366
auto *whereClauseGenEnv =
2330
2367
reqAttr->getDerivativeGenericEnvironment (original);
2331
2368
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters (
@@ -2336,11 +2373,29 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
2336
2373
llvm::raw_string_ostream os (reqDiffAttrString);
2337
2374
reqAttr->print (os, req, omitWrtClause, /* omitDerivativeFunctions*/ true );
2338
2375
os.flush ();
2339
- diags
2340
- .diagnose (match.Witness ,
2341
- diag::protocol_witness_missing_differentiable_attr,
2342
- reqDiffAttrString)
2343
- .fixItInsert (match.Witness ->getStartLoc (), reqDiffAttrString + ' ' );
2376
+ // If the witness has less-than-public visibility and is declared in a
2377
+ // different file than the conformance, emit a specialized diagnostic.
2378
+ if (witness->getFormalAccess () < AccessLevel::Public &&
2379
+ conformance->getDeclContext ()->getModuleScopeContext () !=
2380
+ witness->getDeclContext ()->getModuleScopeContext ()) {
2381
+ diags
2382
+ .diagnose (
2383
+ witness,
2384
+ diag::
2385
+ protocol_witness_missing_differentiable_attr_nonpublic_other_file,
2386
+ reqDiffAttrString, witness->getDescriptiveKind (),
2387
+ witness->getFullName (), req->getDescriptiveKind (),
2388
+ req->getFullName (), conformance->getType (),
2389
+ conformance->getProtocol ()->getDeclaredInterfaceType ())
2390
+ .fixItInsert (match.Witness ->getStartLoc (), reqDiffAttrString + ' ' );
2391
+ }
2392
+ // Otherwise, emit a general "missing attribute" diagnostic.
2393
+ else {
2394
+ diags
2395
+ .diagnose (witness, diag::protocol_witness_missing_differentiable_attr,
2396
+ reqDiffAttrString)
2397
+ .fixItInsert (witness->getStartLoc (), reqDiffAttrString + ' ' );
2398
+ }
2344
2399
break ;
2345
2400
}
2346
2401
}
0 commit comments