Skip to content

Commit 0ab9598

Browse files
committed
[AutoDiff upstream] Add @differentiable ASTScope support.
`@differentiable` attributes may contain `where` clauses referencing generic parameters from some generic context, just like `@_specialize` attributes. Without special ASTScope support for `@differentiable` attributes, ASTScopeLookup.cpp logic tries to resolve the generic parameter `DeclName`s in `where` clauses based on source location alone (`ASTScopeImpl::findChildContaining`) and fails. The fix is to add a special `DifferentiableAttributeScope`, mimicking `SpecializeAttributeScope`. Every `@differentiable` attribute has its own scope, derived from the declaration on which it is declared. Unlike `@_specialize`, `@differentiable` may also be declared on `AbstractStorageDecl` declarations (subscripts and variables). Upstreams #27451. Progress towards TF-828: upstream `@differentiable` attribute type-checking.
1 parent 999836b commit 0ab9598

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

include/swift/AST/ASTScope.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,41 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
15681568
DeclConsumer) const override;
15691569
};
15701570

1571+
/// A `@differentiable` attribute scope.
1572+
///
1573+
/// This exists because `@differentiable` attribute may have a `where` clause
1574+
/// referring to generic parameters from some generic context.
1575+
class DifferentiableAttributeScope final : public ASTScopeImpl {
1576+
public:
1577+
DifferentiableAttr *const differentiableAttr;
1578+
ValueDecl *const attributedDeclaration;
1579+
1580+
DifferentiableAttributeScope(DifferentiableAttr *diffAttr, ValueDecl *decl)
1581+
: differentiableAttr(diffAttr), attributedDeclaration(decl) {}
1582+
virtual ~DifferentiableAttributeScope() {}
1583+
1584+
std::string getClassName() const override;
1585+
SourceRange
1586+
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
1587+
NullablePtr<const void> addressForPrinting() const override {
1588+
return differentiableAttr;
1589+
}
1590+
1591+
NullablePtr<AbstractStorageDecl>
1592+
getEnclosingAbstractStorageDecl() const override;
1593+
1594+
NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
1595+
return differentiableAttr;
1596+
}
1597+
NullablePtr<const void> getReferrent() const override;
1598+
1599+
protected:
1600+
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
1601+
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
1602+
DeclConsumer) const override;
1603+
bool doesContextMatchStartingContext(const DeclContext *) const override;
1604+
};
1605+
15711606
class SubscriptDeclScope final : public ASTScopeImpl {
15721607
public:
15731608
SubscriptDecl *const decl;

lib/AST/ASTScope.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ DEFINE_GET_CLASS_NAME(ClosureParametersScope)
228228
DEFINE_GET_CLASS_NAME(ClosureBodyScope)
229229
DEFINE_GET_CLASS_NAME(TopLevelCodeScope)
230230
DEFINE_GET_CLASS_NAME(SpecializeAttributeScope)
231+
DEFINE_GET_CLASS_NAME(DifferentiableAttributeScope)
231232
DEFINE_GET_CLASS_NAME(SubscriptDeclScope)
232233
DEFINE_GET_CLASS_NAME(VarDeclScope)
233234
DEFINE_GET_CLASS_NAME(EnumElementScope)

lib/AST/ASTScopeCreation.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) {
6464
static SourceRange getRangeableSourceRange(const SpecializeAttr *a) {
6565
return a->getRange();
6666
}
67+
static SourceRange getRangeableSourceRange(const DifferentiableAttr *a) {
68+
return a->getRange();
69+
}
6770
static SourceRange getRangeableSourceRange(const ASTNode n) {
6871
return n.getSourceRange();
6972
}
@@ -98,6 +101,17 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) {
98101
llvm::errs() << "SpecializeAttr\n";
99102
}
100103

104+
static void dumpRangeable(const DifferentiableAttr *a,
105+
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
106+
static void dumpRangeable(const DifferentiableAttr *a, llvm::raw_ostream &f) {
107+
llvm::errs() << "DifferentiableAttr\n";
108+
}
109+
static void dumpRangeable(DifferentiableAttr *a,
110+
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
111+
static void dumpRangeable(DifferentiableAttr *a, llvm::raw_ostream &f) {
112+
llvm::errs() << "DifferentiableAttr\n";
113+
}
114+
101115
/// For Debugging
102116
template <typename T>
103117
bool doesRangeableRangeMatch(const T *x, const SourceManager &SM,
@@ -439,6 +453,22 @@ class ScopeCreator final {
439453
fn(specializeAttr);
440454
}
441455

456+
void forEachDifferentiableAttrInSourceOrder(
457+
Decl *decl, function_ref<void(DifferentiableAttr *)> fn) {
458+
std::vector<DifferentiableAttr *> sortedDifferentiableAttrs;
459+
for (auto *attr : decl->getAttrs())
460+
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr))
461+
// NOTE(TF-835): Skipping implicit `@differentiable` attributes is
462+
// necessary to avoid verification failure in
463+
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
464+
// Perhaps this check may no longer be necessary after TF-835: robust
465+
// `@derivative` attribute lowering.
466+
if (!diffAttr->isImplicit())
467+
sortedDifferentiableAttrs.push_back(diffAttr);
468+
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))
469+
fn(diffAttr);
470+
}
471+
442472
std::vector<ASTNode> expandIfConfigClausesThenCullAndSortElementsOrMembers(
443473
ArrayRef<ASTNode> input) const {
444474
auto cleanedupNodes = sortBySourceRange(cull(expandIfConfigClauses(input)));
@@ -1039,6 +1069,13 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder(
10391069
return enclosingAbstractStorageDecl == ad->getStorage();
10401070
});
10411071

1072+
// Create scopes for `@differentiable` attributes.
1073+
forEachDifferentiableAttrInSourceOrder(
1074+
asd, [&](DifferentiableAttr *diffAttr) {
1075+
ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
1076+
parent, diffAttr, asd);
1077+
});
1078+
10421079
// Sort in order to include synthesized ones, which are out of order.
10431080
for (auto *accessor : sortBySourceRange(accessorsToScope))
10441081
addToScopeTree(accessor, parent);
@@ -1183,6 +1220,7 @@ NO_NEW_INSERTION_POINT(WholeClosureScope)
11831220
NO_EXPANSION(GenericParamScope)
11841221
NO_EXPANSION(ClosureParametersScope)
11851222
NO_EXPANSION(SpecializeAttributeScope)
1223+
NO_EXPANSION(DifferentiableAttributeScope)
11861224
NO_EXPANSION(ConditionalClausePatternUseScope)
11871225
NO_EXPANSION(LookupParentDiversionScope)
11881226

@@ -1353,6 +1391,13 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
13531391
scopeCreator.ifUniqueConstructExpandAndInsert<SpecializeAttributeScope>(
13541392
this, specializeAttr, decl);
13551393
});
1394+
// Create scopes for `@differentiable` attributes.
1395+
scopeCreator.forEachDifferentiableAttrInSourceOrder(
1396+
decl, [&](DifferentiableAttr *diffAttr) {
1397+
scopeCreator
1398+
.ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
1399+
this, diffAttr, decl);
1400+
});
13561401
// Create scopes for generic and ordinary parameters.
13571402
// For a subscript declaration, the generic and ordinary parameters are in an
13581403
// ancestor scope, so don't make them here.
@@ -1681,6 +1726,10 @@ SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
16811726
return getParent().get()->getEnclosingAbstractStorageDecl();
16821727
}
16831728
NullablePtr<AbstractStorageDecl>
1729+
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
1730+
return getParent().get()->getEnclosingAbstractStorageDecl();
1731+
}
1732+
NullablePtr<AbstractStorageDecl>
16841733
AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl() const {
16851734
return getParent().get()->getEnclosingAbstractStorageDecl();
16861735
}
@@ -1807,6 +1856,7 @@ GET_REFERRENT(AbstractStmtScope, getStmt())
18071856
GET_REFERRENT(CaptureListScope, getExpr())
18081857
GET_REFERRENT(WholeClosureScope, getExpr())
18091858
GET_REFERRENT(SpecializeAttributeScope, specializeAttr)
1859+
GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr)
18101860
GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope(this));
18111861

18121862
const Decl *

lib/AST/ASTScopeLookup.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,19 @@ bool GenericParamScope::doesContextMatchStartingContext(
208208
return false;
209209
}
210210

211+
bool DifferentiableAttributeScope::doesContextMatchStartingContext(
212+
const DeclContext *context) const {
213+
// Need special logic to handle case where `attributedDeclaration` is an
214+
// `AbstractStorageDecl` (`SubscriptDecl` or `VarDecl`). The initial starting
215+
// context in `ASTScopeImpl::findStartingScopeForLookup` will be an accessor
216+
// of the `attributedDeclaration`.
217+
if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration))
218+
for (auto accessor : asd->getAllAccessors())
219+
if (up_cast<DeclContext>(accessor) == context)
220+
return true;
221+
return false;
222+
}
223+
211224
#pragma mark lookup methods that run once per scope
212225

213226
void ASTScopeImpl::lookup(SmallVectorImpl<const ASTScopeImpl *> &history,
@@ -438,6 +451,25 @@ bool SpecializeAttributeScope::lookupLocalsOrMembers(
438451
return false;
439452
}
440453

454+
bool DifferentiableAttributeScope::lookupLocalsOrMembers(
455+
ArrayRef<const ASTScopeImpl *>, DeclConsumer consumer) const {
456+
auto visitAbstractFunctionDecl = [&](AbstractFunctionDecl *afd) {
457+
if (auto *params = afd->getGenericParams())
458+
for (auto *param : params->getParams())
459+
if (consumer.consume({param}, DeclVisibilityKind::GenericParameter))
460+
return true;
461+
return false;
462+
};
463+
if (auto *afd = dyn_cast<AbstractFunctionDecl>(attributedDeclaration)) {
464+
return visitAbstractFunctionDecl(afd);
465+
} else if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration)) {
466+
for (auto *accessor : asd->getAllAccessors())
467+
if (visitAbstractFunctionDecl(accessor))
468+
return true;
469+
}
470+
return false;
471+
}
472+
441473
bool BraceStmtScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
442474
DeclConsumer consumer) const {
443475
// All types and functions are visible anywhere within a brace statement

lib/AST/ASTScopeSourceRange.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ SourceRange SpecializeAttributeScope::getSourceRangeOfThisASTNode(
193193
return specializeAttr->getRange();
194194
}
195195

196+
SourceRange DifferentiableAttributeScope::getSourceRangeOfThisASTNode(
197+
const bool omitAssertions) const {
198+
return differentiableAttr->getRange();
199+
}
200+
196201
SourceRange AbstractFunctionBodyScope::getSourceRangeOfThisASTNode(
197202
const bool omitAssertions) const {
198203
return decl->getBodySourceRange();

0 commit comments

Comments
 (0)