@@ -5214,49 +5214,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
5214
5214
return originalType;
5215
5215
}
5216
5216
5217
- // / Given a `@differentiable` attribute, attempts to resolve the original
5218
- // / `AbstractFunctionDecl` for which it is registered, using the declaration
5219
- // / on which it is actually declared. On error, emits diagnostic and returns
5220
- // / `nullptr`.
5221
- AbstractFunctionDecl *
5222
- resolveDifferentiableAttrOriginalFunction (DifferentiableAttr *attr) {
5223
- auto *D = attr->getOriginalDeclaration ();
5224
- assert (D &&
5225
- " Original declaration should be resolved by parsing/deserialization" );
5226
- auto *original = dyn_cast<AbstractFunctionDecl>(D);
5227
- if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5228
- // If `@differentiable` attribute is declared directly on a
5229
- // `AbstractStorageDecl` (a stored/computed property or subscript),
5230
- // forward the attribute to the storage's getter.
5231
- // TODO(TF-129): Forward `@differentiable` attributes to setters after
5232
- // differentiation supports inout parameters.
5233
- // TODO(TF-1080): Forward `@differentiable` attributes to `read` and
5234
- // `modify` accessors after differentiation supports `inout` parameters.
5235
- if (!asd->getDeclContext ()->isModuleScopeContext ()) {
5236
- original = asd->getSynthesizedAccessor (AccessorKind::Get);
5237
- } else {
5238
- original = nullptr ;
5239
- }
5240
- }
5241
- // Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
5242
- // TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5243
- // coroutines.
5244
- if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5245
- if (!accessor->isGetter () && !accessor->isSetter ())
5246
- original = nullptr ;
5247
- // Diagnose if original `AbstractFunctionDecl` could not be resolved.
5248
- if (!original) {
5249
- diagnoseAndRemoveAttr (D, attr, diag::invalid_decl_attribute, attr);
5250
- attr->setInvalid ();
5251
- return nullptr ;
5252
- }
5253
- // If the original function has an error interface type, return.
5254
- // A diagnostic should have already been emitted.
5255
- if (original->getInterfaceType ()->hasError ())
5256
- return nullptr ;
5257
- return original;
5258
- }
5259
-
5260
5217
// / Given a `@differentiable` attribute, attempts to resolve the derivative
5261
5218
// / generic signature. The derivative generic signature is returned as
5262
5219
// / `derivativeGenSig`. On error, emits diagnostic, assigns `nullptr` to
@@ -5435,11 +5392,11 @@ bool resolveDifferentiableAttrDifferentiabilityParameters(
5435
5392
5436
5393
// / Checks whether differentiable programming is enabled for the given
5437
5394
// / differentiation-related attribute. Returns true on error.
5438
- bool checkIfDifferentiableProgrammingEnabled (ASTContext &ctx ,
5439
- DeclAttribute *attr,
5440
- DeclContext *DC) {
5395
+ static bool checkIfDifferentiableProgrammingEnabled (DeclAttribute *attr ,
5396
+ Decl *D) {
5397
+ auto &ctx = D-> getASTContext ();
5441
5398
auto &diags = ctx.Diags ;
5442
- auto *SF = DC ->getParentSourceFile ();
5399
+ auto *SF = D-> getDeclContext () ->getParentSourceFile ();
5443
5400
assert (SF && " Source file not found" );
5444
5401
// The `Differentiable` protocol must be available.
5445
5402
// If unavailable, the `_Differentiation` module should be imported.
@@ -5452,31 +5409,36 @@ bool checkIfDifferentiableProgrammingEnabled(ASTContext &ctx,
5452
5409
return true ;
5453
5410
}
5454
5411
5455
- IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate (
5456
- Evaluator &evaluator, DifferentiableAttr *attr) const {
5457
- // Skip type-checking for implicit `@differentiable` attributes. We currently
5458
- // assume that all implicit `@differentiable` attributes are valid.
5459
- //
5460
- // Motivation: some implicit attributes do not have a `where` clause, and this
5461
- // function assumes that the `where` clauses exist. Propagating `where`
5462
- // clauses and requirements consistently is a larger problem, to be revisited.
5463
- if (attr->isImplicit ())
5464
- return nullptr ;
5412
+ static IndexSubset *
5413
+ resolveDiffParamIndices (AbstractFunctionDecl *original,
5414
+ DifferentiableAttr *attr,
5415
+ GenericSignature derivativeGenSig) {
5416
+ auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment ();
5465
5417
5466
- auto *D = attr->getOriginalDeclaration ();
5467
- auto &ctx = D->getASTContext ();
5468
- auto &diags = ctx.Diags ;
5469
- // `@differentiable` attribute requires experimental differentiable
5470
- // programming to be enabled.
5471
- if (checkIfDifferentiableProgrammingEnabled (ctx, attr, D->getDeclContext ()))
5472
- return nullptr ;
5418
+ // Compute the derivative function type.
5419
+ auto originalFnRemappedTy = original->getInterfaceType ()->castTo <AnyFunctionType>();
5420
+ if (derivativeGenEnv)
5421
+ originalFnRemappedTy =
5422
+ derivativeGenEnv->mapTypeIntoContext (originalFnRemappedTy)
5423
+ ->castTo <AnyFunctionType>();
5473
5424
5474
- // Resolve the original `AbstractFunctionDecl`.
5475
- auto *original = resolveDifferentiableAttrOriginalFunction (attr);
5476
- if (!original)
5425
+ // Resolve and validate the differentiability parameters.
5426
+ IndexSubset *resolvedDiffParamIndices = nullptr ;
5427
+ if (resolveDifferentiableAttrDifferentiabilityParameters (
5428
+ attr, original, originalFnRemappedTy, derivativeGenEnv,
5429
+ resolvedDiffParamIndices))
5477
5430
return nullptr ;
5478
5431
5479
- auto *originalFnTy = original->getInterfaceType ()->castTo <AnyFunctionType>();
5432
+ return resolvedDiffParamIndices;
5433
+ }
5434
+
5435
+
5436
+ static IndexSubset *
5437
+ typecheckDifferentiableAttrforDecl (AbstractFunctionDecl *original,
5438
+ DifferentiableAttr *attr,
5439
+ IndexSubset *resolvedDiffParamIndices = nullptr ) {
5440
+ auto &ctx = original->getASTContext ();
5441
+ auto &diags = ctx.Diags ;
5480
5442
5481
5443
// Diagnose if original function has opaque result types.
5482
5444
if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl ()) {
@@ -5523,69 +5485,161 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
5523
5485
}
5524
5486
5525
5487
// Resolve the derivative generic signature.
5526
- GenericSignature derivativeGenSig = nullptr ;
5527
- if (resolveDifferentiableAttrDerivativeGenericSignature (attr, original,
5488
+ GenericSignature derivativeGenSig = attr->getDerivativeGenericSignature ();
5489
+ if (!derivativeGenSig &&
5490
+ resolveDifferentiableAttrDerivativeGenericSignature (attr, original,
5528
5491
derivativeGenSig))
5529
5492
return nullptr ;
5530
- auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment ();
5531
-
5532
- // Compute the derivative function type.
5533
- auto originalFnRemappedTy = originalFnTy;
5534
- if (derivativeGenEnv)
5535
- originalFnRemappedTy =
5536
- derivativeGenEnv->mapTypeIntoContext (originalFnRemappedTy)
5537
- ->castTo <AnyFunctionType>();
5538
5493
5539
5494
// Resolve and validate the differentiability parameters.
5540
- IndexSubset * resolvedDiffParamIndices = nullptr ;
5541
- if ( resolveDifferentiableAttrDifferentiabilityParameters (
5542
- attr, original, originalFnRemappedTy, derivativeGenEnv,
5543
- resolvedDiffParamIndices) )
5495
+ if (! resolvedDiffParamIndices)
5496
+ resolvedDiffParamIndices = resolveDiffParamIndices (original, attr,
5497
+ derivativeGenSig);
5498
+ if (! resolvedDiffParamIndices)
5544
5499
return nullptr ;
5545
5500
5546
- if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5547
- // Remove `@differentiable` attribute from storage declaration to prevent
5548
- // duplicate attribute registration during SILGen.
5549
- D->getAttrs ().removeAttribute (attr);
5550
- // Transfer `@differentiable` attribute from storage declaration to
5551
- // getter accessor.
5552
- auto *getterDecl = asd->getOpaqueAccessor (AccessorKind::Get);
5553
- auto *newAttr = DifferentiableAttr::create (
5554
- getterDecl, /* implicit*/ true , attr->AtLoc , attr->getRange (),
5555
- attr->getDifferentiabilityKind (), resolvedDiffParamIndices,
5556
- attr->getDerivativeGenericSignature ());
5557
- auto insertion = ctx.DifferentiableAttrs .try_emplace (
5558
- {getterDecl, resolvedDiffParamIndices}, newAttr);
5559
- // Reject duplicate `@differentiable` attributes.
5560
- if (!insertion.second ) {
5561
- diagnoseAndRemoveAttr (D, attr, diag::differentiable_attr_duplicate);
5562
- diags.diagnose (insertion.first ->getSecond ()->getLocation (),
5563
- diag::differentiable_attr_duplicate_note);
5564
- return nullptr ;
5565
- }
5566
- getterDecl->getAttrs ().add (newAttr);
5567
- // Register derivative function configuration.
5568
- auto *resultIndices = IndexSubset::get (ctx, 1 , {0 });
5569
- getterDecl->addDerivativeFunctionConfiguration (
5570
- {resolvedDiffParamIndices, resultIndices, derivativeGenSig});
5571
- return resolvedDiffParamIndices;
5572
- }
5573
5501
// Reject duplicate `@differentiable` attributes.
5574
5502
auto insertion =
5575
- ctx.DifferentiableAttrs .try_emplace ({D , resolvedDiffParamIndices}, attr);
5503
+ ctx.DifferentiableAttrs .try_emplace ({original , resolvedDiffParamIndices}, attr);
5576
5504
if (!insertion.second && insertion.first ->getSecond () != attr) {
5577
- diagnoseAndRemoveAttr (D , attr, diag::differentiable_attr_duplicate);
5505
+ diagnoseAndRemoveAttr (original , attr, diag::differentiable_attr_duplicate);
5578
5506
diags.diagnose (insertion.first ->getSecond ()->getLocation (),
5579
5507
diag::differentiable_attr_duplicate_note);
5580
5508
return nullptr ;
5581
5509
}
5510
+
5582
5511
// Register derivative function configuration.
5583
5512
auto *resultIndices = IndexSubset::get (ctx, 1 , {0 });
5584
5513
original->addDerivativeFunctionConfiguration (
5585
5514
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
5586
5515
return resolvedDiffParamIndices;
5587
5516
}
5588
5517
5518
+ // / Given a `@differentiable` attribute, attempts to resolve the original
5519
+ // / `AbstractFunctionDecl` for which it is registered, using the declaration
5520
+ // / on which it is actually declared. On error, emits diagnostic and returns
5521
+ // / `nullptr`.
5522
+ static AbstractFunctionDecl *
5523
+ resolveDifferentiableAttrOriginalFunction (DifferentiableAttr *attr) {
5524
+ auto *D = attr->getOriginalDeclaration ();
5525
+ auto *original = dyn_cast<AbstractFunctionDecl>(D);
5526
+
5527
+ // Non-`get`/`set` accessors are not yet supported: `read`, and `modify`.
5528
+ // TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5529
+ // coroutines.
5530
+ if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5531
+ if (!accessor->isGetter () && !accessor->isSetter ())
5532
+ original = nullptr ;
5533
+
5534
+ // Diagnose if original `AbstractFunctionDecl` could not be resolved.
5535
+ if (!original) {
5536
+ diagnoseAndRemoveAttr (D, attr, diag::invalid_decl_attribute, attr);
5537
+ attr->setInvalid ();
5538
+ return nullptr ;
5539
+ }
5540
+
5541
+ // If the original function has an error interface type, return.
5542
+ // A diagnostic should have already been emitted.
5543
+ if (original->getInterfaceType ()->hasError ())
5544
+ return nullptr ;
5545
+
5546
+ return original;
5547
+ }
5548
+
5549
+ static IndexSubset *
5550
+ resolveDifferentiableAccessors (DifferentiableAttr *attr,
5551
+ AbstractStorageDecl *asd) {
5552
+ auto typecheckAccessor = [&](AccessorDecl *ad) -> IndexSubset* {
5553
+ GenericSignature derivativeGenSig = nullptr ;
5554
+ if (resolveDifferentiableAttrDerivativeGenericSignature (attr, ad,
5555
+ derivativeGenSig))
5556
+ return nullptr ;
5557
+
5558
+ IndexSubset *resolvedDiffParamIndices = resolveDiffParamIndices (ad, attr,
5559
+ derivativeGenSig);
5560
+ if (!resolvedDiffParamIndices)
5561
+ return nullptr ;
5562
+
5563
+ auto *newAttr = DifferentiableAttr::create (
5564
+ ad, /* implicit*/ true , attr->AtLoc , attr->getRange (),
5565
+ attr->getDifferentiabilityKind (), resolvedDiffParamIndices,
5566
+ attr->getDerivativeGenericSignature ());
5567
+ ad->getAttrs ().add (newAttr);
5568
+
5569
+ if (!typecheckDifferentiableAttrforDecl (ad, attr,
5570
+ resolvedDiffParamIndices))
5571
+ return nullptr ;
5572
+
5573
+ return resolvedDiffParamIndices;
5574
+ };
5575
+
5576
+ // No getters / setters for global variables
5577
+ if (asd->getDeclContext ()->isModuleScopeContext ()) {
5578
+ diagnoseAndRemoveAttr (asd, attr, diag::invalid_decl_attribute, attr);
5579
+ attr->setInvalid ();
5580
+ return nullptr ;
5581
+ }
5582
+
5583
+ if (!typecheckAccessor (asd->getSynthesizedAccessor (AccessorKind::Get)))
5584
+ return nullptr ;
5585
+
5586
+ if (asd->supportsMutation ()) {
5587
+ // FIXME: Class-typed values have reference semantics and can be freely
5588
+ // mutated. Thus, they should be treated like inout parameters for the
5589
+ // purposes of @differentiable and @derivative type-checking. Until
5590
+ // https://github.com/apple/swift/issues/55542 is fixed, check if setter has
5591
+ // computed semantic results and do not typecheck if they are none
5592
+ // (class-typed `self' parameter is not treated as a "semantic result"
5593
+ // currently)
5594
+ if (!asd->getDeclContext ()->getSelfClassDecl ())
5595
+ if (!typecheckAccessor (asd->getSynthesizedAccessor (AccessorKind::Set)))
5596
+ return nullptr ;
5597
+ }
5598
+
5599
+ // Remove `@differentiable` attribute from storage declaration to prevent
5600
+ // duplicate attribute registration during SILGen.
5601
+ asd->getAttrs ().removeAttribute (attr);
5602
+
5603
+ // Here we are effectively removing attribute from original decl, therefore no
5604
+ // index subset for us
5605
+ return nullptr ;
5606
+ }
5607
+
5608
+
5609
+ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate (
5610
+ Evaluator &evaluator, DifferentiableAttr *attr) const {
5611
+ // Skip type-checking for implicit `@differentiable` attributes. We currently
5612
+ // assume that all implicit `@differentiable` attributes are valid.
5613
+ //
5614
+ // Motivation: some implicit attributes do not have a `where` clause, and this
5615
+ // function assumes that the `where` clauses exist. Propagating `where`
5616
+ // clauses and requirements consistently is a larger problem, to be revisited.
5617
+ if (attr->isImplicit ())
5618
+ return nullptr ;
5619
+
5620
+ auto *D = attr->getOriginalDeclaration ();
5621
+ assert (D &&
5622
+ " Original declaration should be resolved by parsing/deserialization" );
5623
+
5624
+ // `@differentiable` attribute requires experimental differentiable
5625
+ // programming to be enabled.
5626
+ if (checkIfDifferentiableProgrammingEnabled (attr, D))
5627
+ return nullptr ;
5628
+
5629
+ // If `@differentiable` attribute is declared directly on a
5630
+ // `AbstractStorageDecl` (a stored/computed property or subscript),
5631
+ // forward the attribute to the storage's getter / setter
5632
+ if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
5633
+ return resolveDifferentiableAccessors (attr, asd);
5634
+
5635
+ // Resolve the original `AbstractFunctionDecl`.
5636
+ auto *original = resolveDifferentiableAttrOriginalFunction (attr);
5637
+ if (!original)
5638
+ return nullptr ;
5639
+
5640
+ return typecheckDifferentiableAttrforDecl (original, attr);
5641
+ }
5642
+
5589
5643
void AttributeChecker::visitDifferentiableAttr (DifferentiableAttr *attr) {
5590
5644
// Call `getParameterIndices` to trigger
5591
5645
// `DifferentiableAttributeTypeCheckRequest`.
@@ -5608,7 +5662,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
5608
5662
auto &diags = Ctx.Diags ;
5609
5663
// `@derivative` attribute requires experimental differentiable programming
5610
5664
// to be enabled.
5611
- if (checkIfDifferentiableProgrammingEnabled (Ctx, attr, D-> getDeclContext () ))
5665
+ if (checkIfDifferentiableProgrammingEnabled (attr, D))
5612
5666
return true ;
5613
5667
auto *derivative = cast<FuncDecl>(D);
5614
5668
auto originalName = attr->getOriginalFunctionName ();
0 commit comments