Skip to content

[AutoDiff] Enable @derivative attribute qualified declaration names. #28892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,11 @@ class DerivativeAttr final
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The base type repr for the referenced original function. This field is
/// non-null only for parsed attributes that reference a qualified original
/// declaration. This field is not serialized; type-checking uses it to
/// resolve the original declaration, which is serialized.
TypeRepr *BaseTypeRepr;
/// The original function name.
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
Expand All @@ -1803,23 +1808,27 @@ class DerivativeAttr final
Optional<AutoDiffDerivativeFunctionKind> Kind = None;

explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameRefWithLoc original,
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameRefWithLoc original, IndexSubset *indices);
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
IndexSubset *parameterIndices);

public:
static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseTypeRepr,
DeclNameRefWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseTypeRepr,
DeclNameRefWithLoc original,
IndexSubset *indices);
IndexSubset *parameterIndices);

TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
DeclNameRefWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
Expand Down Expand Up @@ -1876,9 +1885,10 @@ class TransposeAttr final
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The base type of the original function.
/// This is non-null only when the original function is not top-level (i.e. it
/// is an instance/static method).
/// The base type repr for the referenced original function. This field is
/// non-null only for parsed attributes that reference a qualified original
/// declaration. This field is not serialized; type-checking uses it to
/// resolve the original declaration, which is serialized.
TypeRepr *BaseTypeRepr;
/// The original function name.
DeclNameRefWithLoc OriginalFunctionName;
Expand All @@ -1895,7 +1905,7 @@ class TransposeAttr final

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
IndexSubset *indices);
IndexSubset *parameterIndices);

public:
static TransposeAttr *create(ASTContext &context, bool implicit,
Expand All @@ -1906,7 +1916,7 @@ class TransposeAttr final
static TransposeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
IndexSubset *indices);
IndexSubset *parameterIndices);

TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
DeclNameRefWithLoc getOriginalFunctionName() const {
Expand Down
33 changes: 18 additions & 15 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1518,41 +1518,43 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)),
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
IndexSubset *indices)
IndexSubset *parameterIndices)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
}
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
ParameterIndices(parameterIndices) {}

DerivativeAttr *
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameRefWithLoc originalName,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
return new (mem) DerivativeAttr(implicit, atLoc, baseRange, baseTypeRepr,
std::move(originalName), params);
}

DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
IndexSubset *indices) {
IndexSubset *parameterIndices) {
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
std::move(originalName), indices);
return new (mem) DerivativeAttr(implicit, atLoc, baseRange, baseTypeRepr,
std::move(originalName), parameterIndices);
}

TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
Expand All @@ -1568,10 +1570,11 @@ TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,

TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName, IndexSubset *indices)
DeclNameRefWithLoc originalName,
IndexSubset *parameterIndices)
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
ParameterIndices(indices) {}
ParameterIndices(parameterIndices) {}

TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
Expand All @@ -1588,10 +1591,10 @@ TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType,
DeclNameRefWithLoc originalName,
IndexSubset *indices) {
IndexSubset *parameterIndices) {
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
std::move(originalName), indices);
std::move(originalName), parameterIndices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
Expand Down
16 changes: 8 additions & 8 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,8 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
return false;
}

/// parseQualifiedDeclName
/// Parses an optional base type, followed by a declaration name.
/// Returns true on error (if declaration name could not be parsed).
///
/// \verbatim
/// qualified-decl-name:
Expand All @@ -1120,8 +1121,6 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
/// identifier generic-args? ('.' identifier generic-args?)*
/// \endverbatim
///
/// Parses an optional base type, followed by a declaration name.
/// Returns true on error (if declaration name could not be parsed).
// TODO(TF-1066): Use module qualified name syntax/parsing instead of custom
// qualified name syntax/parsing.
static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
Expand All @@ -1147,7 +1146,8 @@ static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
///
/// \verbatim
/// derivative-attribute-arguments:
/// '(' 'of' ':' decl-name (',' differentiation-params-clause)? ')'
/// '(' 'of' ':' qualified-decl-name (',' differentiation-params-clause)?
/// ')'
/// \endverbatim
ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
SourceLoc loc) {
Expand Down Expand Up @@ -1206,16 +1206,16 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
/*DeclModifier*/ false);
return makeParserError();
}
return ParserResult<DerivativeAttr>(
DerivativeAttr::create(Context, /*implicit*/ false, atLoc,
SourceRange(loc, rParenLoc), original, params));
return ParserResult<DerivativeAttr>(DerivativeAttr::create(
Context, /*implicit*/ false, atLoc, SourceRange(loc, rParenLoc), baseType,
original, params));
}

/// Parse a `@transpose(of:)` attribute, returning true on error.
///
/// \verbatim
/// transpose-attribute-arguments:
/// '(' 'of' ':' decl-name (',' transposed-params-clause)? ')'
/// '(' 'of' ':' qualified-decl-name (',' transposed-params-clause)? ')'
/// \endverbatim
ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
SourceLoc loc) {
Expand Down
18 changes: 14 additions & 4 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3531,16 +3531,26 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
return derivative->getParent() == func->getParent();
};

auto lookupOptions =
defaultMemberLookupOptions | NameLookupFlags::IgnoreAccessControl;
auto resolution = TypeResolution::forContextual(derivative->getDeclContext());
Type baseType;
if (auto *baseTypeRepr = attr->getBaseTypeRepr()) {
TypeResolutionOptions options = None;
options |= TypeResolutionFlags::AllowModule;
baseType = resolution.resolveType(baseTypeRepr, options);
}
if (baseType && baseType->hasError())
return;
auto lookupOptions = attr->getBaseTypeRepr()
? defaultMemberLookupOptions
: defaultUnqualifiedLookupOptions;
auto derivativeTypeCtx = derivative->getInnermostTypeContext();
if (!derivativeTypeCtx)
derivativeTypeCtx = derivative->getParent();
assert(derivativeTypeCtx);

// Look up original function.
auto *originalAFD = findAbstractFunctionDecl(
originalName.Name, originalName.Loc.getBaseNameLoc(), /*baseType*/ Type(),
originalName.Name, originalName.Loc.getBaseNameLoc(), baseType,
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
hasValidTypeContext, invalidTypeContextDiagnostic);
Expand Down Expand Up @@ -3667,7 +3677,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
}

// Reject different-file derivative registration.
// TODO(TF-1021): Lift this restriction.
// TODO(TF-1021): Lift same-file derivative registration restriction.
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
diagnoseAndRemoveAttr(attr,
diag::derivative_attr_not_in_same_file_as_original);
Expand Down
4 changes: 4 additions & 0 deletions lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,10 @@ Type TypeChecker::resolveIdentifierType(
if (!result) return nullptr;

if (auto moduleTy = result->getAs<ModuleType>()) {
// Allow module types only if flag is specified.
if (options.contains(TypeResolutionFlags::AllowModule))
return moduleTy;
// Otherwise, emit an error.
if (!options.contains(TypeResolutionFlags::SilenceErrors)) {
auto moduleName = moduleTy->getModule()->getName();
diags.diagnose(Components.back()->getNameLoc(),
Expand Down
3 changes: 3 additions & 0 deletions lib/Sema/TypeCheckType.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ enum class TypeResolutionFlags : uint16_t {

/// Whether we should not produce diagnostics if the type is invalid.
SilenceErrors = 1 << 10,

/// Whether to allow module declaration types.
AllowModule = 1 << 11
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brentdax: I added you as a reviewer because this ad-hoc support for module-qualified and type-qualified declaration names for @derivative and @transpose attributes overlaps with your work on module-qualified names.

Some context: the main "qualified name" use case for @derivative and @transpose is module-qualified names. Type-qualified names are no longer necessary after a @transpose type-checking rules revamp (TF-1060).

After syntax exists for your work on module-qualified names (e.g. Module::fooName), we can drop our ad-hoc support and switch to your work!

};

/// Type resolution contexts that require special handling.
Expand Down
11 changes: 6 additions & 5 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4202,11 +4202,12 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
parametersBitVector[i] = parameters[i];
auto *indices = IndexSubset::get(ctx, parametersBitVector);

auto *derivAttr = DerivativeAttr::create(
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
derivAttr->setOriginalFunction(origDecl);
derivAttr->setDerivativeKind(*derivativeKind);
Attr = derivAttr;
auto *derivativeAttr =
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
/*baseType*/ nullptr, origName, indices);
derivativeAttr->setOriginalFunction(origDecl);
derivativeAttr->setDerivativeKind(*derivativeKind);
Attr = derivativeAttr;
break;
}

Expand Down
11 changes: 10 additions & 1 deletion test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ extension StaticMethod {
return (x, { $0 })
}

@derivative(of: foo)
// Test qualified declaration name.
@derivative(of: StaticMethod.foo)
static func vjpFoo(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
Expand Down Expand Up @@ -232,6 +233,14 @@ extension InstanceMethod {
return (x, { $0 + $1 })
}

// Test qualified declaration name.
@derivative(of: InstanceMethod.foo, wrt: x)
func jvpFooWrtX(x: Self) -> (
value: Self, differential: (TangentVector) -> (TangentVector)
) {
return (x, { $0 })
}

@derivative(of: generic)
func vjpGeneric<T: Differentiable>(_ x: T) -> (
value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector)
Expand Down
Loading