Skip to content

[AutoDiff upstream] Add @differentiable ASTScope support. #29171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/swift/AST/ASTScope.h
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,41 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
DeclConsumer) const override;
};

/// A `@differentiable` attribute scope.
///
/// This exists because `@differentiable` attribute may have a `where` clause
/// referring to generic parameters from some generic context.
class DifferentiableAttributeScope final : public ASTScopeImpl {
public:
DifferentiableAttr *const differentiableAttr;
ValueDecl *const attributedDeclaration;

DifferentiableAttributeScope(DifferentiableAttr *diffAttr, ValueDecl *decl)
: differentiableAttr(diffAttr), attributedDeclaration(decl) {}
virtual ~DifferentiableAttributeScope() {}

std::string getClassName() const override;
SourceRange
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
NullablePtr<const void> addressForPrinting() const override {
return differentiableAttr;
}

NullablePtr<AbstractStorageDecl>
getEnclosingAbstractStorageDecl() const override;

NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
return differentiableAttr;
}
NullablePtr<const void> getReferrent() const override;

protected:
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
DeclConsumer) const override;
bool doesContextMatchStartingContext(const DeclContext *) const override;
};

class SubscriptDeclScope final : public ASTScopeImpl {
public:
SubscriptDecl *const decl;
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTScope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ DEFINE_GET_CLASS_NAME(ClosureParametersScope)
DEFINE_GET_CLASS_NAME(ClosureBodyScope)
DEFINE_GET_CLASS_NAME(TopLevelCodeScope)
DEFINE_GET_CLASS_NAME(SpecializeAttributeScope)
DEFINE_GET_CLASS_NAME(DifferentiableAttributeScope)
DEFINE_GET_CLASS_NAME(SubscriptDeclScope)
DEFINE_GET_CLASS_NAME(VarDeclScope)
DEFINE_GET_CLASS_NAME(EnumElementScope)
Expand Down
50 changes: 50 additions & 0 deletions lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) {
static SourceRange getRangeableSourceRange(const SpecializeAttr *a) {
return a->getRange();
}
static SourceRange getRangeableSourceRange(const DifferentiableAttr *a) {
return a->getRange();
}
static SourceRange getRangeableSourceRange(const ASTNode n) {
return n.getSourceRange();
}
Expand Down Expand Up @@ -98,6 +101,17 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) {
llvm::errs() << "SpecializeAttr\n";
}

static void dumpRangeable(const DifferentiableAttr *a,
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
static void dumpRangeable(const DifferentiableAttr *a, llvm::raw_ostream &f) {
llvm::errs() << "DifferentiableAttr\n";
}
static void dumpRangeable(DifferentiableAttr *a,
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
static void dumpRangeable(DifferentiableAttr *a, llvm::raw_ostream &f) {
llvm::errs() << "DifferentiableAttr\n";
}

/// For Debugging
template <typename T>
bool doesRangeableRangeMatch(const T *x, const SourceManager &SM,
Expand Down Expand Up @@ -439,6 +453,22 @@ class ScopeCreator final {
fn(specializeAttr);
}

void forEachDifferentiableAttrInSourceOrder(
Decl *decl, function_ref<void(DifferentiableAttr *)> fn) {
std::vector<DifferentiableAttr *> sortedDifferentiableAttrs;
for (auto *attr : decl->getAttrs())
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr))
// NOTE(TF-835): Skipping implicit `@differentiable` attributes is
// necessary to avoid verification failure in
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
// Perhaps this check may no longer be necessary after TF-835: robust
// `@derivative` attribute lowering.
if (!diffAttr->isImplicit())
sortedDifferentiableAttrs.push_back(diffAttr);
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))
fn(diffAttr);
}

std::vector<ASTNode> expandIfConfigClausesThenCullAndSortElementsOrMembers(
ArrayRef<ASTNode> input) const {
auto cleanedupNodes = sortBySourceRange(cull(expandIfConfigClauses(input)));
Expand Down Expand Up @@ -1039,6 +1069,13 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder(
return enclosingAbstractStorageDecl == ad->getStorage();
});

// Create scopes for `@differentiable` attributes.
forEachDifferentiableAttrInSourceOrder(
asd, [&](DifferentiableAttr *diffAttr) {
ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
parent, diffAttr, asd);
});

// Sort in order to include synthesized ones, which are out of order.
for (auto *accessor : sortBySourceRange(accessorsToScope))
addToScopeTree(accessor, parent);
Expand Down Expand Up @@ -1183,6 +1220,7 @@ NO_NEW_INSERTION_POINT(WholeClosureScope)
NO_EXPANSION(GenericParamScope)
NO_EXPANSION(ClosureParametersScope)
NO_EXPANSION(SpecializeAttributeScope)
NO_EXPANSION(DifferentiableAttributeScope)
NO_EXPANSION(ConditionalClausePatternUseScope)
NO_EXPANSION(LookupParentDiversionScope)

Expand Down Expand Up @@ -1353,6 +1391,13 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
scopeCreator.ifUniqueConstructExpandAndInsert<SpecializeAttributeScope>(
this, specializeAttr, decl);
});
// Create scopes for `@differentiable` attributes.
scopeCreator.forEachDifferentiableAttrInSourceOrder(
decl, [&](DifferentiableAttr *diffAttr) {
scopeCreator
.ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
this, diffAttr, decl);
});
// Create scopes for generic and ordinary parameters.
// For a subscript declaration, the generic and ordinary parameters are in an
// ancestor scope, so don't make them here.
Expand Down Expand Up @@ -1681,6 +1726,10 @@ SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
}
NullablePtr<AbstractStorageDecl>
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
}
NullablePtr<AbstractStorageDecl>
AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
}
Expand Down Expand Up @@ -1807,6 +1856,7 @@ GET_REFERRENT(AbstractStmtScope, getStmt())
GET_REFERRENT(CaptureListScope, getExpr())
GET_REFERRENT(WholeClosureScope, getExpr())
GET_REFERRENT(SpecializeAttributeScope, specializeAttr)
GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr)
GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope(this));

const Decl *
Expand Down
32 changes: 32 additions & 0 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ bool GenericParamScope::doesContextMatchStartingContext(
return false;
}

bool DifferentiableAttributeScope::doesContextMatchStartingContext(
const DeclContext *context) const {
// Need special logic to handle case where `attributedDeclaration` is an
// `AbstractStorageDecl` (`SubscriptDecl` or `VarDecl`). The initial starting
// context in `ASTScopeImpl::findStartingScopeForLookup` will be an accessor
// of the `attributedDeclaration`.
if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration))
for (auto accessor : asd->getAllAccessors())
if (up_cast<DeclContext>(accessor) == context)
return true;
return false;
}

#pragma mark lookup methods that run once per scope

void ASTScopeImpl::lookup(SmallVectorImpl<const ASTScopeImpl *> &history,
Expand Down Expand Up @@ -438,6 +451,25 @@ bool SpecializeAttributeScope::lookupLocalsOrMembers(
return false;
}

bool DifferentiableAttributeScope::lookupLocalsOrMembers(
ArrayRef<const ASTScopeImpl *>, DeclConsumer consumer) const {
auto visitAbstractFunctionDecl = [&](AbstractFunctionDecl *afd) {
if (auto *params = afd->getGenericParams())
for (auto *param : params->getParams())
if (consumer.consume({param}, DeclVisibilityKind::GenericParameter))
return true;
return false;
};
if (auto *afd = dyn_cast<AbstractFunctionDecl>(attributedDeclaration)) {
return visitAbstractFunctionDecl(afd);
} else if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration)) {
for (auto *accessor : asd->getAllAccessors())
if (visitAbstractFunctionDecl(accessor))
return true;
}
return false;
}

bool BraceStmtScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
DeclConsumer consumer) const {
// All types and functions are visible anywhere within a brace statement
Expand Down
5 changes: 5 additions & 0 deletions lib/AST/ASTScopeSourceRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ SourceRange SpecializeAttributeScope::getSourceRangeOfThisASTNode(
return specializeAttr->getRange();
}

SourceRange DifferentiableAttributeScope::getSourceRangeOfThisASTNode(
const bool omitAssertions) const {
return differentiableAttr->getRange();
}

SourceRange AbstractFunctionBodyScope::getSourceRangeOfThisASTNode(
const bool omitAssertions) const {
return decl->getBodySourceRange();
Expand Down