Skip to content

[AutoDiff] Support derivative registration for more declaration kinds. #28468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1679,9 +1679,9 @@ class DifferentiatingAttr final
friend TrailingObjects;

/// The original function name.
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
DeclNameWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
Expand All @@ -1706,9 +1706,15 @@ class DifferentiatingAttr final
DeclNameWithLoc original,
IndexSubset *indices);

DeclNameWithLoc getOriginal() const { return Original; }
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
DeclNameWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
AbstractFunctionDecl *getOriginalFunction() const {
return OriginalFunction;
}
void setOriginalFunction(AbstractFunctionDecl *decl) {
OriginalFunction = decl;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
Expand Down Expand Up @@ -1750,9 +1756,9 @@ class TransposingAttr final
/// is an instance/static method).
TypeRepr *BaseType;
/// The original function name.
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
DeclNameWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
Expand All @@ -1779,10 +1785,15 @@ class TransposingAttr final
IndexSubset *indices);

TypeRepr *getBaseType() const { return BaseType; }
DeclNameWithLoc getOriginal() const { return Original; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
DeclNameWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
AbstractFunctionDecl *getOriginalFunction() const {
return OriginalFunction;
}
void setOriginalFunction(AbstractFunctionDecl *decl) {
OriginalFunction = decl;
}

/// The parsed transposing parameters, i.e. the list of parameters
/// specified in 'wrt:'.
Expand Down
9 changes: 5 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2954,9 +2954,8 @@ NOTE(differentiable_attr_duplicate_note,none,
"other attribute declared here", ())
ERROR(differentiable_attr_function_not_same_type_context,none,
"%0 is not defined in the current type context", (DeclName))
ERROR(differentiable_attr_specified_not_function,none,
"%0 is not a function to be used as derivative function",
(DeclName))
ERROR(differentiable_attr_derivative_not_function,none,
"registered derivative %0 must be a 'func' declaration", (DeclName))
ERROR(differentiable_attr_class_derivative_not_final,none,
"class member derivative must be final", ())
ERROR(differentiable_attr_ambiguous_function_identifier,none,
Expand Down Expand Up @@ -3020,6 +3019,8 @@ ERROR(differentiating_attr_overload_not_found,none,
"could not find function %0 with expected type %1", (DeclName, Type))
ERROR(differentiating_attr_not_in_same_file_as_original,none,
"derivative not in the same file as the original function", ())
ERROR(differentiating_attr_original_stored_property_unsupported,none,
"cannot register derivative for stored property %0", (DeclName))
ERROR(differentiating_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))

Expand All @@ -3033,7 +3034,7 @@ ERROR(transposing_attr_cannot_use_named_wrt_params,none,
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
(Identifier))
ERROR(transposing_attr_result_value_not_differentiable,none,
"'@transposing' attribute requires original function result to "
"'@transposing' attribute requires original function result %0 to "
"conform to 'Differentiable'", (Type))

// differentiation `wrt` parameters clause
Expand Down
23 changes: 13 additions & 10 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer << '(';
auto *attr = cast<DifferentiatingAttr>(this);
auto *derivative = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginal().Name;
Printer << attr->getOriginalFunctionName().Name;
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters());
if (!diffParamsString.empty())
Expand All @@ -941,7 +941,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer << '(';
auto *attr = cast<TransposingAttr>(this);
auto *transpose = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginal().Name;
Printer << attr->getOriginalFunctionName().Name;
auto transParamsString = getTransposedParametersClauseString(
transpose, attr->getParameterIndices(), attr->getParsedParameters());
if (!transParamsString.empty())
Expand Down Expand Up @@ -1570,19 +1570,21 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
// SWIFT_ENABLE_TENSORFLOW
DifferentiatingAttr::DifferentiatingAttr(
bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
DeclNameWithLoc originalName, ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), NumParsedParameters(params.size()) {
OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), ParameterIndices(indices) {}
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
Expand All @@ -1607,20 +1609,21 @@ DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,

TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
BaseType(baseType), Original(std::move(original)),
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::uninitialized_copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original, IndexSubset *indices)
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
BaseType(baseType), Original(std::move(original)),
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
ParameterIndices(indices) {}

TransposingAttr *
Expand Down
30 changes: 13 additions & 17 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ bool Parser::parseDifferentiableAttributeArguments(
SyntaxParsingContext FuncDeclNameContext(
SyntaxContext, SyntaxKind::FunctionDeclName);
Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID,
{ label });
{label});
result.Name =
parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc,
funcDiag, /*allowOperators=*/true,
Expand Down Expand Up @@ -1165,11 +1165,14 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
// Parse the name of the function.
SyntaxParsingContext FuncDeclNameContext(
SyntaxContext, SyntaxKind::FunctionDeclName);
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to
// enable, e.g. `@differentiating(init)` and
// `@differentiating(subscript)`.
original.Name = parseUnqualifiedDeclName(
/*afterDot*/ false, original.Loc,
/*afterDot*/ true, original.Loc,
diag::attr_differentiating_expected_original_name,
/*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true);

/*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true,
/*allowDeinitAndSubscript*/ true);
if (consumeIfTrailingComma())
return makeParserError();
}
Expand Down Expand Up @@ -1228,19 +1231,13 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
if (parseBaseTypeForQualifiedDeclName(P, baseType))
return true;

// If base type was parsed and has at least one component, then there was a
// dot before the current token.
bool afterDot = false;
if (baseType) {
if (auto ident = dyn_cast<IdentTypeRepr>(baseType)) {
auto components = ident->getComponentRange();
afterDot = std::distance(components.begin(), components.end()) > 0;
}
}
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to enable
// initializer and subscript lookup.
original.Name =
P.parseUnqualifiedDeclName(afterDot, original.Loc, nameParseError,
/*allowOperators*/ true,
/*allowZeroArgCompoundNames*/ true);
P.parseUnqualifiedDeclName(/*afterDot*/ true, original.Loc,
nameParseError, /*allowOperators*/ true,
/*allowZeroArgCompoundNames*/ true,
/*allowDeinitAndSubscript*/ true);

// The base type is optional, but the final unqualified decl name is not.
// If name could not be parsed, return true for error.
Expand Down Expand Up @@ -1285,7 +1282,6 @@ ParserResult<TransposingAttr> Parser::parseTransposingAttribute(SourceLoc atLoc,
diag::attr_transposing_expected_original_name,
baseType, original))
return makeParserError();

if (consumeIfTrailingComma())
return makeParserError();
}
Expand Down
Loading