Skip to content

Commit e254b29

Browse files
authored
Add ASTScope support for @differentiable attribute. (#27451)
`@differentiable` attribute where clauses may refer to generic parameters from some generic context. Without special ASTScope support for `@differentiable` attributes, ASTScopeLookup.cpp logic tries to resolve the generic parameter DeclNames in the where clause based on source location alone (`ASTScopeImpl::findChildContaining`) and fails. The fix is to add a special `DifferentiableAttributeScope`, mimicking `SpecializeAttributeScope`. Every `@differentiable` attribute has its own scope, derived from the declaration on which it is declared. Unlike `@_specialize`, `@differentiable` may also be declared on `AbstractStorageDecl` declarations (subscripts and variables). Resolves TF-815. `Decl::getSourceRangeIncludingAttrs` should not consider implicit `@differentiable` attributes generated during `@differentiating` attribute type-checking. TF-835 tracks robust lowering for `@differentiating` attributes that does not involve generating implicit `@differentiable` attributes, circumventing this issue.
1 parent cca9586 commit e254b29

File tree

9 files changed

+254
-6
lines changed

9 files changed

+254
-6
lines changed

include/swift/AST/ASTScope.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,44 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
15311531
DeclConsumer) const override;
15321532
};
15331533

1534+
// SWIFT_ENABLE_TENSORFLOW
1535+
/// A `@differentiable` attribute scope.
1536+
/// This exists because `@differentiable` attribute may have a where clause
1537+
/// referring to generic parameters from some generic context.
1538+
class DifferentiableAttributeScope final : public ASTScopeImpl {
1539+
public:
1540+
DifferentiableAttr *const differentiableAttr;
1541+
ValueDecl *const attributedDeclaration;
1542+
1543+
DifferentiableAttributeScope(DifferentiableAttr *diffAttr,
1544+
ValueDecl *decl)
1545+
: differentiableAttr(diffAttr), attributedDeclaration(decl) {
1546+
}
1547+
virtual ~DifferentiableAttributeScope() {}
1548+
1549+
std::string getClassName() const override;
1550+
SourceRange
1551+
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
1552+
NullablePtr<const void> addressForPrinting() const override {
1553+
return differentiableAttr;
1554+
}
1555+
1556+
NullablePtr<AbstractStorageDecl>
1557+
getEnclosingAbstractStorageDecl() const override;
1558+
1559+
NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
1560+
return differentiableAttr;
1561+
}
1562+
NullablePtr<const void> getReferrent() const override;
1563+
1564+
protected:
1565+
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
1566+
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
1567+
DeclConsumer) const override;
1568+
bool doesContextMatchStartingContext(const DeclContext *) const override;
1569+
};
1570+
// SWIFT_ENABLE_TENSORFLOW END
1571+
15341572
class SubscriptDeclScope final : public ASTScopeImpl {
15351573
public:
15361574
SubscriptDecl *const decl;

lib/AST/ASTScope.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ DEFINE_GET_CLASS_NAME(ClosureParametersScope)
231231
DEFINE_GET_CLASS_NAME(ClosureBodyScope)
232232
DEFINE_GET_CLASS_NAME(TopLevelCodeScope)
233233
DEFINE_GET_CLASS_NAME(SpecializeAttributeScope)
234+
// SWIFT_ENABLE_TENSORFLOW
235+
DEFINE_GET_CLASS_NAME(DifferentiableAttributeScope)
236+
// SWIFT_ENABLE_TENSORFLOW END
234237
DEFINE_GET_CLASS_NAME(SubscriptDeclScope)
235238
DEFINE_GET_CLASS_NAME(VarDeclScope)
236239
DEFINE_GET_CLASS_NAME(EnumElementScope)

lib/AST/ASTScopeCreation.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) {
6060
static SourceRange getRangeableSourceRange(const SpecializeAttr *a) {
6161
return a->getRange();
6262
}
63+
// SWIFT_ENABLE_TENSORFLOW
64+
static SourceRange getRangeableSourceRange(const DifferentiableAttr *a) {
65+
return a->getRange();
66+
}
67+
// SWIFT_ENABLE_TENSORFLOW END
6368
static SourceRange getRangeableSourceRange(const ASTNode n) {
6469
return n.getSourceRange();
6570
}
@@ -94,6 +99,19 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) {
9499
llvm::errs() << "SpecializeAttr\n";
95100
}
96101

102+
// SWIFT_ENABLE_TENSORFLOW
103+
static void dumpRangeable(const DifferentiableAttr *a,
104+
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
105+
static void dumpRangeable(const DifferentiableAttr *a, llvm::raw_ostream &f) {
106+
llvm::errs() << "DifferentiableAttr\n";
107+
}
108+
static void dumpRangeable(DifferentiableAttr *a,
109+
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
110+
static void dumpRangeable(DifferentiableAttr *a, llvm::raw_ostream &f) {
111+
llvm::errs() << "DifferentiableAttr\n";
112+
}
113+
// SWIFT_ENABLE_TENSORFLOW END
114+
97115
/// For Debugging
98116
template <typename T>
99117
bool doesRangeableRangeMatch(const T *x, const SourceManager &SM,
@@ -435,6 +453,24 @@ class ScopeCreator final {
435453
fn(specializeAttr);
436454
}
437455

456+
// SWIFT_ENABLE_TENSORFLOW
457+
void forEachDifferentiableAttrInSourceOrder(
458+
Decl *decl, function_ref<void(DifferentiableAttr *)> fn) {
459+
std::vector<DifferentiableAttr *> sortedDifferentiableAttrs;
460+
for (auto *attr : decl->getAttrs())
461+
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr))
462+
// NOTE(TF-835): Skipping implicit `@differentiable` attributes is
463+
// necessary to avoid verification failure:
464+
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
465+
// Perhaps this check is no longer necessary after TF-835: robust
466+
// `@differentiating` attribute lowering.
467+
if (!diffAttr->isImplicit())
468+
sortedDifferentiableAttrs.push_back(diffAttr);
469+
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))
470+
fn(diffAttr);
471+
}
472+
// SWIFT_ENABLE_TENSORFLOW END
473+
438474
std::vector<ASTNode> expandIfConfigClausesThenCullAndSortElementsOrMembers(
439475
ArrayRef<ASTNode> input) const {
440476
auto cleanedupNodes = sortBySourceRange(cull(expandIfConfigClauses(input)));
@@ -1045,6 +1081,15 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder(
10451081
return enclosingAbstractStorageDecl == ad->getStorage();
10461082
});
10471083

1084+
// SWIFT_ENABLE_TENSORFLOW
1085+
// Create scopes for `@differentiable` attributes.
1086+
forEachDifferentiableAttrInSourceOrder(
1087+
asd, [&](DifferentiableAttr *diffAttr) {
1088+
ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
1089+
parent, diffAttr, asd);
1090+
});
1091+
// SWIFT_ENABLE_TENSORFLOW END
1092+
10481093
// Sort in order to include synthesized ones, which are out of order.
10491094
// Part of rdar://53921774 rm extra copy
10501095
for (auto *accessor : sortBySourceRange(accessorsToScope))
@@ -1152,6 +1197,9 @@ NO_EXPANSION(GenericParamScope)
11521197
NO_EXPANSION(ASTSourceFileScope)
11531198
NO_EXPANSION(ClosureParametersScope)
11541199
NO_EXPANSION(SpecializeAttributeScope)
1200+
// SWIFT_ENABLE_TENSORFLOW
1201+
NO_EXPANSION(DifferentiableAttributeScope)
1202+
// SWIFT_ENABLE_TENSORFLOW END
11551203
NO_EXPANSION(ConditionalClausePatternUseScope)
11561204
NO_EXPANSION(LookupParentDiversionScope)
11571205

@@ -1309,6 +1357,17 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
13091357
scopeCreator.ifUniqueConstructExpandAndInsert<SpecializeAttributeScope>(
13101358
this, specializeAttr, decl);
13111359
});
1360+
1361+
// SWIFT_ENABLE_TENSORFLOW
1362+
// Create scopes for `@differentiable` attributes.
1363+
scopeCreator.forEachDifferentiableAttrInSourceOrder(
1364+
decl, [&](DifferentiableAttr *diffAttr) {
1365+
scopeCreator
1366+
.ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
1367+
this, diffAttr, decl);
1368+
});
1369+
// SWIFT_ENABLE_TENSORFLOW END
1370+
13121371
// Create scopes for generic and ordinary parameters.
13131372
// For a subscript declaration, the generic and ordinary parameters are in an
13141373
// ancestor scope, so don't make them here.
@@ -1636,6 +1695,12 @@ NullablePtr<AbstractStorageDecl>
16361695
SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
16371696
return getParent().get()->getEnclosingAbstractStorageDecl();
16381697
}
1698+
// SWIFT_ENABLE_TENSORFLOW
1699+
NullablePtr<AbstractStorageDecl>
1700+
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
1701+
return getParent().get()->getEnclosingAbstractStorageDecl();
1702+
}
1703+
// SWIFT_ENABLE_TENSORFLOW END
16391704
NullablePtr<AbstractStorageDecl>
16401705
AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl() const {
16411706
return getParent().get()->getEnclosingAbstractStorageDecl();
@@ -1784,6 +1849,9 @@ GET_REFERRENT(AbstractStmtScope, getStmt())
17841849
GET_REFERRENT(CaptureListScope, getExpr())
17851850
GET_REFERRENT(WholeClosureScope, getExpr())
17861851
GET_REFERRENT(SpecializeAttributeScope, specializeAttr)
1852+
// SWIFT_ENABLE_TENSORFLOW
1853+
GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr)
1854+
// SWIFT_ENABLE_TENSORFLOW END
17871855
GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope(this));
17881856

17891857
const Decl *

lib/AST/ASTScopeLookup.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,21 @@ bool GenericParamScope::doesContextMatchStartingContext(
194194
return false;
195195
}
196196

197+
// SWIFT_ENABLE_TENSORFLOW
198+
bool DifferentiableAttributeScope::doesContextMatchStartingContext(
199+
const DeclContext *context) const {
200+
// Need special logic to handle case where `attributedDeclaration` is an
201+
// `AbstractStorageDecl` (`SubscriptDecl` or `VarDecl`). The initial starting
202+
// context in `ASTScopeImpl::findStartingScopeForLookup` will be an accessor
203+
// of the `attributedDeclaration`.
204+
if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration))
205+
for (auto accessor : asd->getAllAccessors())
206+
if (up_cast<DeclContext>(accessor) == context)
207+
return true;
208+
return false;
209+
}
210+
// SWIFT_ENABLE_TENSORFLOW END
211+
197212
#pragma mark lookup methods that run once per scope
198213

199214
void ASTScopeImpl::lookup(SmallVectorImpl<const ASTScopeImpl *> &history,
@@ -424,6 +439,27 @@ bool SpecializeAttributeScope::lookupLocalsOrMembers(
424439
return false;
425440
}
426441

442+
// SWIFT_ENABLE_TENSORFLOW
443+
bool DifferentiableAttributeScope::lookupLocalsOrMembers(
444+
ArrayRef<const ASTScopeImpl *>, DeclConsumer consumer) const {
445+
auto visitAbstractFunctionDecl = [&](AbstractFunctionDecl *afd) {
446+
if (auto *params = afd->getGenericParams())
447+
for (auto *param : params->getParams())
448+
if (consumer.consume({param}, DeclVisibilityKind::GenericParameter))
449+
return true;
450+
return false;
451+
};
452+
if (auto *afd = dyn_cast<AbstractFunctionDecl>(attributedDeclaration)) {
453+
return visitAbstractFunctionDecl(afd);
454+
} else if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration)) {
455+
for (auto *accessor : asd->getAllAccessors())
456+
if (visitAbstractFunctionDecl(accessor))
457+
return true;
458+
}
459+
return false;
460+
}
461+
// SWIFT_ENABLE_TENSORFLOW END
462+
427463
bool BraceStmtScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
428464
DeclConsumer consumer) const {
429465
// All types and functions are visible anywhere within a brace statement

lib/AST/ASTScopeSourceRange.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,13 @@ SourceRange SpecializeAttributeScope::getSourceRangeOfThisASTNode(
193193
return specializeAttr->getRange();
194194
}
195195

196+
// SWIFT_ENABLE_TENSORFLOW
197+
SourceRange DifferentiableAttributeScope::getSourceRangeOfThisASTNode(
198+
const bool omitAssertions) const {
199+
return differentiableAttr->getRange();
200+
}
201+
// SWIFT_ENABLE_TENSORFLOW END
202+
196203
SourceRange AbstractFunctionBodyScope::getSourceRangeOfThisASTNode(
197204
const bool omitAssertions) const {
198205
return decl->getBodySourceRange();

lib/AST/Decl.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,16 @@ SourceRange Decl::getSourceRangeIncludingAttrs() const {
464464
}
465465

466466
for (auto Attr : getAttrs()) {
467+
// SWIFT_ENABLE_TENSORFLOW
468+
// Skip implicitly `@differentiable` attribute generated during
469+
// `@differentiating` attribute type-checking.
470+
// TODO(TF-835): Instead of generating implicit `@differentiable`
471+
// attributes, lower `@differentiating` attributes to `[differentiable]`
472+
// attributes on the referenced declaration.
473+
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(Attr))
474+
if (diffAttr->isImplicit())
475+
continue;
476+
// SWIFT_ENABLE_TENSORFLOW END
467477
if (Attr->getRange().isValid())
468478
Range.widen(Attr->getRangeWithAt());
469479
}

lib/AST/UnqualifiedLookup.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -485,12 +485,7 @@ void UnqualifiedLookupFactory::performUnqualifiedLookup() {
485485
DC, initialIsCascadingUse};
486486
const bool crosscheckUnqualifiedLookup =
487487
Ctx.LangOpts.CrosscheckUnqualifiedLookup;
488-
// SWIFT_ENABLE_TENSORFLOW
489-
// NOTE(TF-815): using AST scopes for lookup causes standard library
490-
// type-checking for `@differentiable` attributes to fail.
491-
if ((false)) {
492-
// if (useASTScopesForLookup()) {
493-
// SWIFT_ENABLE_TENSORFLOW END
488+
if (useASTScopesForLookup()) {
494489
static bool haveWarned = false;
495490
if (!haveWarned && Ctx.LangOpts.WarnIfASTScopeLookup) {
496491
haveWarned = true;
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// Check that ASTScope lookup works for `@differentiable` attribute.
3+
4+
// NOTE(TF-815): Without custom scope support, ASTScopeLookup crashes for
5+
// `@differentiable` attribute with where clauses on subscript and `var`
6+
// declarations.
7+
8+
// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup
9+
10+
struct Test<Element> {
11+
var element: Element
12+
}
13+
extension Test: Differentiable where Element: Differentiable {}
14+
extension Test {
15+
@differentiable(where Element: Differentiable)
16+
init(_ element: Element) {
17+
self.element = element
18+
}
19+
20+
@differentiable(where Element: Differentiable)
21+
func method() -> Element {
22+
element
23+
}
24+
25+
@differentiable(where T: Differentiable)
26+
func method<T>(_ x: T) -> T {
27+
x
28+
}
29+
30+
// NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
31+
@differentiable(where Element: Differentiable)
32+
subscript(implicitGetterOnly_ : Void) -> Element {
33+
element
34+
}
35+
36+
subscript(explicitGetterAndSetter _: Void) -> Element {
37+
@differentiable(where Element: Differentiable)
38+
get { element }
39+
set {}
40+
}
41+
42+
// NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
43+
@differentiable(where Element: Differentiable)
44+
var computedProperty: Element {
45+
element
46+
}
47+
48+
var computedPropertyExplicitGetter: Element {
49+
@differentiable(where Element: Differentiable)
50+
get {
51+
element
52+
}
53+
}
54+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// Check that ASTScope lookup works for `@differentiating` attribute.
3+
4+
// NOTE(TF-835): This test is only necessary because `@differentiating`
5+
// attribute type-checking generates implicit `@differentiable` attributes
6+
// on the referenced declaration. Robust lowering for `@differentiating`
7+
// attributes should make special logic regarding implicit `@differentiable`
8+
// attributes unnecessary.
9+
10+
// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup
11+
12+
struct Test<Element> {
13+
var element: Element
14+
}
15+
extension Test: Differentiable where Element: Differentiable {}
16+
extension Test {
17+
static func +(lhs: Self, rhs: Self) -> Self {
18+
lhs
19+
}
20+
static func -(lhs: Self, rhs: Self) -> Self {
21+
lhs
22+
}
23+
}
24+
25+
extension Test where Element : Differentiable {
26+
@differentiating(+)
27+
internal static func _vjpAdd(lhs: Self, rhs: Self)
28+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
29+
return (lhs + rhs, { v in (v, v) })
30+
}
31+
32+
@differentiating(-)
33+
internal static func _vjpSubtract(lhs: Self, rhs: Self)
34+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
35+
return (lhs + rhs, { v in (v, v) })
36+
}
37+
}

0 commit comments

Comments
 (0)