Skip to content

Add ASTScope support for @differentiable attribute. #27451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/swift/AST/ASTScope.h
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,44 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
DeclConsumer) const override;
};

// SWIFT_ENABLE_TENSORFLOW
/// A `@differentiable` attribute scope.
/// This exists because `@differentiable` attribute may have a where clause
/// referring to generic parameters from some generic context.
class DifferentiableAttributeScope final : public ASTScopeImpl {
public:
DifferentiableAttr *const differentiableAttr;
ValueDecl *const attributedDeclaration;

DifferentiableAttributeScope(DifferentiableAttr *diffAttr,
ValueDecl *decl)
: differentiableAttr(diffAttr), attributedDeclaration(decl) {
}
virtual ~DifferentiableAttributeScope() {}

std::string getClassName() const override;
SourceRange
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
NullablePtr<const void> addressForPrinting() const override {
return differentiableAttr;
}

NullablePtr<AbstractStorageDecl>
getEnclosingAbstractStorageDecl() const override;

NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
return differentiableAttr;
}
NullablePtr<const void> getReferrent() const override;

protected:
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
DeclConsumer) const override;
bool doesContextMatchStartingContext(const DeclContext *) const override;
};
// SWIFT_ENABLE_TENSORFLOW END

class SubscriptDeclScope final : public ASTScopeImpl {
public:
SubscriptDecl *const decl;
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/ASTScope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ DEFINE_GET_CLASS_NAME(ClosureParametersScope)
DEFINE_GET_CLASS_NAME(ClosureBodyScope)
DEFINE_GET_CLASS_NAME(TopLevelCodeScope)
DEFINE_GET_CLASS_NAME(SpecializeAttributeScope)
// SWIFT_ENABLE_TENSORFLOW
DEFINE_GET_CLASS_NAME(DifferentiableAttributeScope)
// SWIFT_ENABLE_TENSORFLOW END
DEFINE_GET_CLASS_NAME(SubscriptDeclScope)
DEFINE_GET_CLASS_NAME(VarDeclScope)
DEFINE_GET_CLASS_NAME(EnumElementScope)
Expand Down
68 changes: 68 additions & 0 deletions lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ static SourceRange getRangeableSourceRange(const Rangeable *const p) {
static SourceRange getRangeableSourceRange(const SpecializeAttr *a) {
return a->getRange();
}
// SWIFT_ENABLE_TENSORFLOW
static SourceRange getRangeableSourceRange(const DifferentiableAttr *a) {
return a->getRange();
}
// SWIFT_ENABLE_TENSORFLOW END
static SourceRange getRangeableSourceRange(const ASTNode n) {
return n.getSourceRange();
}
Expand Down Expand Up @@ -94,6 +99,19 @@ static void dumpRangeable(SpecializeAttr *r, llvm::raw_ostream &f) {
llvm::errs() << "SpecializeAttr\n";
}

// SWIFT_ENABLE_TENSORFLOW
static void dumpRangeable(const DifferentiableAttr *a,
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
static void dumpRangeable(const DifferentiableAttr *a, llvm::raw_ostream &f) {
llvm::errs() << "DifferentiableAttr\n";
}
static void dumpRangeable(DifferentiableAttr *a,
llvm::raw_ostream &f) LLVM_ATTRIBUTE_USED;
static void dumpRangeable(DifferentiableAttr *a, llvm::raw_ostream &f) {
llvm::errs() << "DifferentiableAttr\n";
}
// SWIFT_ENABLE_TENSORFLOW END

/// For Debugging
template <typename T>
bool doesRangeableRangeMatch(const T *x, const SourceManager &SM,
Expand Down Expand Up @@ -435,6 +453,24 @@ class ScopeCreator final {
fn(specializeAttr);
}

// SWIFT_ENABLE_TENSORFLOW
void forEachDifferentiableAttrInSourceOrder(
Decl *decl, function_ref<void(DifferentiableAttr *)> fn) {
std::vector<DifferentiableAttr *> sortedDifferentiableAttrs;
for (auto *attr : decl->getAttrs())
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr))
// NOTE(TF-835): Skipping implicit `@differentiable` attributes is
// necessary to avoid verification failure:
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
// Perhaps this check is no longer necessary after TF-835: robust
// `@differentiating` attribute lowering.
if (!diffAttr->isImplicit())
sortedDifferentiableAttrs.push_back(diffAttr);
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))
fn(diffAttr);
}
// SWIFT_ENABLE_TENSORFLOW END

std::vector<ASTNode> expandIfConfigClausesThenCullAndSortElementsOrMembers(
ArrayRef<ASTNode> input) const {
auto cleanedupNodes = sortBySourceRange(cull(expandIfConfigClauses(input)));
Expand Down Expand Up @@ -1045,6 +1081,15 @@ void ScopeCreator::addChildrenForAllLocalizableAccessorsInSourceOrder(
return enclosingAbstractStorageDecl == ad->getStorage();
});

// SWIFT_ENABLE_TENSORFLOW
// Create scopes for `@differentiable` attributes.
forEachDifferentiableAttrInSourceOrder(
asd, [&](DifferentiableAttr *diffAttr) {
ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
parent, diffAttr, asd);
});
// SWIFT_ENABLE_TENSORFLOW END

// Sort in order to include synthesized ones, which are out of order.
// Part of rdar://53921774 rm extra copy
for (auto *accessor : sortBySourceRange(accessorsToScope))
Expand Down Expand Up @@ -1152,6 +1197,9 @@ NO_EXPANSION(GenericParamScope)
NO_EXPANSION(ASTSourceFileScope)
NO_EXPANSION(ClosureParametersScope)
NO_EXPANSION(SpecializeAttributeScope)
// SWIFT_ENABLE_TENSORFLOW
NO_EXPANSION(DifferentiableAttributeScope)
// SWIFT_ENABLE_TENSORFLOW END
NO_EXPANSION(ConditionalClausePatternUseScope)
NO_EXPANSION(LookupParentDiversionScope)

Expand Down Expand Up @@ -1309,6 +1357,17 @@ void AbstractFunctionDeclScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
scopeCreator.ifUniqueConstructExpandAndInsert<SpecializeAttributeScope>(
this, specializeAttr, decl);
});

// SWIFT_ENABLE_TENSORFLOW
// Create scopes for `@differentiable` attributes.
scopeCreator.forEachDifferentiableAttrInSourceOrder(
decl, [&](DifferentiableAttr *diffAttr) {
scopeCreator
.ifUniqueConstructExpandAndInsert<DifferentiableAttributeScope>(
this, diffAttr, decl);
});
// SWIFT_ENABLE_TENSORFLOW END

// Create scopes for generic and ordinary parameters.
// For a subscript declaration, the generic and ordinary parameters are in an
// ancestor scope, so don't make them here.
Expand Down Expand Up @@ -1636,6 +1695,12 @@ NullablePtr<AbstractStorageDecl>
SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
}
// SWIFT_ENABLE_TENSORFLOW
NullablePtr<AbstractStorageDecl>
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
}
// SWIFT_ENABLE_TENSORFLOW END
NullablePtr<AbstractStorageDecl>
AbstractFunctionDeclScope::getEnclosingAbstractStorageDecl() const {
return getParent().get()->getEnclosingAbstractStorageDecl();
Expand Down Expand Up @@ -1784,6 +1849,9 @@ GET_REFERRENT(AbstractStmtScope, getStmt())
GET_REFERRENT(CaptureListScope, getExpr())
GET_REFERRENT(WholeClosureScope, getExpr())
GET_REFERRENT(SpecializeAttributeScope, specializeAttr)
// SWIFT_ENABLE_TENSORFLOW
GET_REFERRENT(DifferentiableAttributeScope, differentiableAttr)
// SWIFT_ENABLE_TENSORFLOW END
GET_REFERRENT(GenericTypeOrExtensionScope, portion->getReferrentOfScope(this));

const Decl *
Expand Down
36 changes: 36 additions & 0 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,21 @@ bool GenericParamScope::doesContextMatchStartingContext(
return false;
}

// SWIFT_ENABLE_TENSORFLOW
bool DifferentiableAttributeScope::doesContextMatchStartingContext(
const DeclContext *context) const {
// Need special logic to handle case where `attributedDeclaration` is an
// `AbstractStorageDecl` (`SubscriptDecl` or `VarDecl`). The initial starting
// context in `ASTScopeImpl::findStartingScopeForLookup` will be an accessor
// of the `attributedDeclaration`.
if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration))
for (auto accessor : asd->getAllAccessors())
if (up_cast<DeclContext>(accessor) == context)
return true;
return false;
}
// SWIFT_ENABLE_TENSORFLOW END

#pragma mark lookup methods that run once per scope

void ASTScopeImpl::lookup(SmallVectorImpl<const ASTScopeImpl *> &history,
Expand Down Expand Up @@ -424,6 +439,27 @@ bool SpecializeAttributeScope::lookupLocalsOrMembers(
return false;
}

// SWIFT_ENABLE_TENSORFLOW
bool DifferentiableAttributeScope::lookupLocalsOrMembers(
ArrayRef<const ASTScopeImpl *>, DeclConsumer consumer) const {
auto visitAbstractFunctionDecl = [&](AbstractFunctionDecl *afd) {
if (auto *params = afd->getGenericParams())
for (auto *param : params->getParams())
if (consumer.consume({param}, DeclVisibilityKind::GenericParameter))
return true;
return false;
};
if (auto *afd = dyn_cast<AbstractFunctionDecl>(attributedDeclaration)) {
return visitAbstractFunctionDecl(afd);
} else if (auto *asd = dyn_cast<AbstractStorageDecl>(attributedDeclaration)) {
for (auto *accessor : asd->getAllAccessors())
if (visitAbstractFunctionDecl(accessor))
return true;
}
return false;
}
// SWIFT_ENABLE_TENSORFLOW END

bool BraceStmtScope::lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
DeclConsumer consumer) const {
// All types and functions are visible anywhere within a brace statement
Expand Down
7 changes: 7 additions & 0 deletions lib/AST/ASTScopeSourceRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ SourceRange SpecializeAttributeScope::getSourceRangeOfThisASTNode(
return specializeAttr->getRange();
}

// SWIFT_ENABLE_TENSORFLOW
SourceRange DifferentiableAttributeScope::getSourceRangeOfThisASTNode(
const bool omitAssertions) const {
return differentiableAttr->getRange();
}
// SWIFT_ENABLE_TENSORFLOW END

SourceRange AbstractFunctionBodyScope::getSourceRangeOfThisASTNode(
const bool omitAssertions) const {
return decl->getBodySourceRange();
Expand Down
10 changes: 10 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,16 @@ SourceRange Decl::getSourceRangeIncludingAttrs() const {
}

for (auto Attr : getAttrs()) {
// SWIFT_ENABLE_TENSORFLOW
// Skip implicitly `@differentiable` attribute generated during
// `@differentiating` attribute type-checking.
// TODO(TF-835): Instead of generating implicit `@differentiable`
// attributes, lower `@differentiating` attributes to `[differentiable]`
// attributes on the referenced declaration.
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(Attr))
if (diffAttr->isImplicit())
continue;
// SWIFT_ENABLE_TENSORFLOW END
if (Attr->getRange().isValid())
Range.widen(Attr->getRangeWithAt());
}
Expand Down
7 changes: 1 addition & 6 deletions lib/AST/UnqualifiedLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,7 @@ void UnqualifiedLookupFactory::performUnqualifiedLookup() {
DC, initialIsCascadingUse};
const bool crosscheckUnqualifiedLookup =
Ctx.LangOpts.CrosscheckUnqualifiedLookup;
// SWIFT_ENABLE_TENSORFLOW
// NOTE(TF-815): using AST scopes for lookup causes standard library
// type-checking for `@differentiable` attributes to fail.
if ((false)) {
// if (useASTScopesForLookup()) {
// SWIFT_ENABLE_TENSORFLOW END
if (useASTScopesForLookup()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this reverts an earlier workaround hack.

static bool haveWarned = false;
if (!haveWarned && Ctx.LangOpts.WarnIfASTScopeLookup) {
haveWarned = true;
Expand Down
54 changes: 54 additions & 0 deletions test/NameBinding/astscope-differentiable-attr.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SWIFT_ENABLE_TENSORFLOW
// Check that ASTScope lookup works for `@differentiable` attribute.

// NOTE(TF-815): Without custom scope support, ASTScopeLookup crashes for
// `@differentiable` attribute with where clauses on subscript and `var`
// declarations.

// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup

struct Test<Element> {
var element: Element
}
extension Test: Differentiable where Element: Differentiable {}
extension Test {
@differentiable(where Element: Differentiable)
init(_ element: Element) {
self.element = element
}

@differentiable(where Element: Differentiable)
func method() -> Element {
element
}

@differentiable(where T: Differentiable)
func method<T>(_ x: T) -> T {
x
}

// NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
@differentiable(where Element: Differentiable)
subscript(implicitGetterOnly_ : Void) -> Element {
element
}

subscript(explicitGetterAndSetter _: Void) -> Element {
@differentiable(where Element: Differentiable)
get { element }
set {}
}

// NOTE(TF-815): This crashed without `DifferentiableAttributeScope` support.
@differentiable(where Element: Differentiable)
var computedProperty: Element {
element
}

var computedPropertyExplicitGetter: Element {
@differentiable(where Element: Differentiable)
get {
element
}
}
}
37 changes: 37 additions & 0 deletions test/NameBinding/astscope-differentiating-attr.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SWIFT_ENABLE_TENSORFLOW
// Check that ASTScope lookup works for `@differentiating` attribute.

// NOTE(TF-835): This test is only necessary because `@differentiating`
// attribute type-checking generates implicit `@differentiable` attributes
// on the referenced declaration. Robust lowering for `@differentiating`
// attributes should make special logic regarding implicit `@differentiable`
// attributes unnecessary.

// RUN: %target-swift-frontend -typecheck %s -enable-astscope-lookup

struct Test<Element> {
var element: Element
}
extension Test: Differentiable where Element: Differentiable {}
extension Test {
static func +(lhs: Self, rhs: Self) -> Self {
lhs
}
static func -(lhs: Self, rhs: Self) -> Self {
lhs
}
}

extension Test where Element : Differentiable {
@differentiating(+)
internal static func _vjpAdd(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs + rhs, { v in (v, v) })
}

@differentiating(-)
internal static func _vjpSubtract(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs + rhs, { v in (v, v) })
}
}