@@ -64,6 +64,9 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) {
64
64
static SourceRange getRangeableSourceRange (const SpecializeAttr *a) {
65
65
return a->getRange ();
66
66
}
67
+ static SourceRange getRangeableSourceRange (const DifferentiableAttr *a) {
68
+ return a->getRange ();
69
+ }
67
70
static SourceRange getRangeableSourceRange (const ASTNode n) {
68
71
return n.getSourceRange ();
69
72
}
@@ -98,6 +101,17 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) {
98
101
llvm::errs () << " SpecializeAttr\n " ;
99
102
}
100
103
104
+ static void dumpRangeable (const DifferentiableAttr *a,
105
+ llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
106
+ static void dumpRangeable (const DifferentiableAttr *a, llvm::raw_ostream &f) {
107
+ llvm::errs () << " DifferentiableAttr\n " ;
108
+ }
109
+ static void dumpRangeable (DifferentiableAttr *a,
110
+ llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
111
+ static void dumpRangeable (DifferentiableAttr *a, llvm::raw_ostream &f) {
112
+ llvm::errs () << " DifferentiableAttr\n " ;
113
+ }
114
+
101
115
// / For Debugging
102
116
template <typename T>
103
117
bool doesRangeableRangeMatch (const T *x, const SourceManager &SM,
@@ -439,6 +453,22 @@ class ScopeCreator final {
439
453
fn (specializeAttr);
440
454
}
441
455
456
+ void forEachDifferentiableAttrInSourceOrder (
457
+ Decl *decl, function_ref<void (DifferentiableAttr *)> fn) {
458
+ std::vector<DifferentiableAttr *> sortedDifferentiableAttrs;
459
+ for (auto *attr : decl->getAttrs ())
460
+ if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr))
461
+ // NOTE(TF-835): Skipping implicit `@differentiable` attributes is
462
+ // necessary to avoid verification failure in
463
+ // `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
464
+ // Perhaps this check may no longer be necessary after TF-835: robust
465
+ // `@derivative` attribute lowering.
466
+ if (!diffAttr->isImplicit ())
467
+ sortedDifferentiableAttrs.push_back (diffAttr);
468
+ for (auto *diffAttr : sortBySourceRange (sortedDifferentiableAttrs))
469
+ fn (diffAttr);
470
+ }
471
+
442
472
std::vector<ASTNode> expandIfConfigClausesThenCullAndSortElementsOrMembers (
443
473
ArrayRef<ASTNode> input) const {
444
474
auto cleanedupNodes = sortBySourceRange (cull (expandIfConfigClauses (input)));
@@ -1039,6 +1069,13 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder(
1039
1069
return enclosingAbstractStorageDecl == ad->getStorage ();
1040
1070
});
1041
1071
1072
+ // Create scopes for `@differentiable` attributes.
1073
+ forEachDifferentiableAttrInSourceOrder (
1074
+ asd, [&](DifferentiableAttr *diffAttr) {
1075
+ ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
1076
+ parent, diffAttr, asd);
1077
+ });
1078
+
1042
1079
// Sort in order to include synthesized ones, which are out of order.
1043
1080
for (auto *accessor : sortBySourceRange (accessorsToScope))
1044
1081
addToScopeTree (accessor, parent);
@@ -1183,6 +1220,7 @@ NO_NEW_INSERTION_POINT(WholeClosureScope)
1183
1220
NO_EXPANSION(GenericParamScope)
1184
1221
NO_EXPANSION(ClosureParametersScope)
1185
1222
NO_EXPANSION(SpecializeAttributeScope)
1223
+ NO_EXPANSION(DifferentiableAttributeScope)
1186
1224
NO_EXPANSION(ConditionalClausePatternUseScope)
1187
1225
NO_EXPANSION(LookupParentDiversionScope)
1188
1226
@@ -1353,6 +1391,13 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
1353
1391
scopeCreator.ifUniqueConstructExpandAndInsert <SpecializeAttributeScope>(
1354
1392
this , specializeAttr, decl);
1355
1393
});
1394
+ // Create scopes for `@differentiable` attributes.
1395
+ scopeCreator.forEachDifferentiableAttrInSourceOrder (
1396
+ decl, [&](DifferentiableAttr *diffAttr) {
1397
+ scopeCreator
1398
+ .ifUniqueConstructExpandAndInsert <DifferentiableAttributeScope>(
1399
+ this , diffAttr, decl);
1400
+ });
1356
1401
// Create scopes for generic and ordinary parameters.
1357
1402
// For a subscript declaration, the generic and ordinary parameters are in an
1358
1403
// ancestor scope, so don't make them here.
@@ -1681,6 +1726,10 @@ SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
1681
1726
return getParent ().get ()->getEnclosingAbstractStorageDecl ();
1682
1727
}
1683
1728
NullablePtr<AbstractStorageDecl>
1729
+ DifferentiableAttributeScope::getEnclosingAbstractStorageDecl () const {
1730
+ return getParent ().get ()->getEnclosingAbstractStorageDecl ();
1731
+ }
1732
+ NullablePtr<AbstractStorageDecl>
1684
1733
AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl () const {
1685
1734
return getParent ().get ()->getEnclosingAbstractStorageDecl ();
1686
1735
}
@@ -1807,6 +1856,7 @@ GET_REFERRENT(AbstractStmtScope, getStmt())
1807
1856
GET_REFERRENT(CaptureListScope, getExpr())
1808
1857
GET_REFERRENT(WholeClosureScope, getExpr())
1809
1858
GET_REFERRENT(SpecializeAttributeScope, specializeAttr)
1859
+ GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr)
1810
1860
GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope (this ));
1811
1861
1812
1862
const Decl *
0 commit comments