Skip to content

Commit 658b7f7

Browse files
authored
[AutoDiff] Support derivative registration for more declaration kinds. (swiftlang#28468)
Make `@differentiating` and `@transposing` attributes support more original declaration kinds: computed properties, subscripts, and initializers. - Change `DifferentiatingAttr` and `TransposingAttr` to store an `AbstractFunctionDecl` representing the original declaration, instead of a `FuncDecl`. - Change attribute parsing to support initializer/subscript `DeclName`s. - Make `TypeChecker::lookupFuncDecl` a static function in TypeCheckAttr.cpp. - Assorted parsing and type-checking gardening. `@differentiating` now has feature parity with `@differentiable(jvp: ..., vjp: ...)` for derivative registration. This is a necessary step towards making `@differentiating` and `@transposing` the canonical mechanism for registering derivative/transpose functions. Registering non-`func` declaration derivatives with `@differentiable` attribute `jvp:`/`vjp:` labels is now explicitly rejected. Resolves TF-281. Todos: - TF-997: support `@transposing` attribute with initializer original declarations. - TF-988: do not reuse `@differentiable` attribute type-checking diagnostics for `@differentiating`/`@transposing` attribute type-checking.
1 parent 64361eb commit 658b7f7

11 files changed

+545
-286
lines changed

include/swift/AST/Attr.h

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,9 +1679,9 @@ class DifferentiatingAttr final
16791679
friend TrailingObjects;
16801680

16811681
/// The original function name.
1682-
DeclNameWithLoc Original;
1683-
/// The original function, resolved by the type checker.
1684-
FuncDecl *OriginalFunction = nullptr;
1682+
DeclNameWithLoc OriginalFunctionName;
1683+
/// The original function declaration, resolved by the type checker.
1684+
AbstractFunctionDecl *OriginalFunction = nullptr;
16851685
/// The number of parsed parameters specified in 'wrt:'.
16861686
unsigned NumParsedParameters = 0;
16871687
/// The differentiation parameters' indices, resolved by the type checker.
@@ -1706,9 +1706,15 @@ class DifferentiatingAttr final
17061706
DeclNameWithLoc original,
17071707
IndexSubset *indices);
17081708

1709-
DeclNameWithLoc getOriginal() const { return Original; }
1710-
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
1711-
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1709+
DeclNameWithLoc getOriginalFunctionName() const {
1710+
return OriginalFunctionName;
1711+
}
1712+
AbstractFunctionDecl *getOriginalFunction() const {
1713+
return OriginalFunction;
1714+
}
1715+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1716+
OriginalFunction = decl;
1717+
}
17121718

17131719
/// The parsed differentiation parameters, i.e. the list of parameters
17141720
/// specified in 'wrt:'.
@@ -1750,9 +1756,9 @@ class TransposingAttr final
17501756
/// is an instance/static method).
17511757
TypeRepr *BaseType;
17521758
/// The original function name.
1753-
DeclNameWithLoc Original;
1754-
/// The original function, resolved by the type checker.
1755-
FuncDecl *OriginalFunction = nullptr;
1759+
DeclNameWithLoc OriginalFunctionName;
1760+
/// The original function declaration, resolved by the type checker.
1761+
AbstractFunctionDecl *OriginalFunction = nullptr;
17561762
/// The number of parsed parameters specified in 'wrt:'.
17571763
unsigned NumParsedParameters = 0;
17581764
/// The differentiation parameters' indices, resolved by the type checker.
@@ -1779,10 +1785,15 @@ class TransposingAttr final
17791785
IndexSubset *indices);
17801786

17811787
TypeRepr *getBaseType() const { return BaseType; }
1782-
DeclNameWithLoc getOriginal() const { return Original; }
1783-
1784-
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
1785-
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1788+
DeclNameWithLoc getOriginalFunctionName() const {
1789+
return OriginalFunctionName;
1790+
}
1791+
AbstractFunctionDecl *getOriginalFunction() const {
1792+
return OriginalFunction;
1793+
}
1794+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1795+
OriginalFunction = decl;
1796+
}
17861797

17871798
/// The parsed transposing parameters, i.e. the list of parameters
17881799
/// specified in 'wrt:'.

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2954,9 +2954,8 @@ NOTE(differentiable_attr_duplicate_note,none,
29542954
"other attribute declared here", ())
29552955
ERROR(differentiable_attr_function_not_same_type_context,none,
29562956
"%0 is not defined in the current type context", (DeclName))
2957-
ERROR(differentiable_attr_specified_not_function,none,
2958-
"%0 is not a function to be used as derivative function",
2959-
(DeclName))
2957+
ERROR(differentiable_attr_derivative_not_function,none,
2958+
"registered derivative %0 must be a 'func' declaration", (DeclName))
29602959
ERROR(differentiable_attr_class_derivative_not_final,none,
29612960
"class member derivative must be final", ())
29622961
ERROR(differentiable_attr_ambiguous_function_identifier,none,
@@ -3020,6 +3019,8 @@ ERROR(differentiating_attr_overload_not_found,none,
30203019
"could not find function %0 with expected type %1", (DeclName, Type))
30213020
ERROR(differentiating_attr_not_in_same_file_as_original,none,
30223021
"derivative not in the same file as the original function", ())
3022+
ERROR(differentiating_attr_original_stored_property_unsupported,none,
3023+
"cannot register derivative for stored property %0", (DeclName))
30233024
ERROR(differentiating_attr_original_already_has_derivative,none,
30243025
"a derivative already exists for %0", (DeclName))
30253026

@@ -3033,7 +3034,7 @@ ERROR(transposing_attr_cannot_use_named_wrt_params,none,
30333034
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
30343035
(Identifier))
30353036
ERROR(transposing_attr_result_value_not_differentiable,none,
3036-
"'@transposing' attribute requires original function result to "
3037+
"'@transposing' attribute requires original function result %0 to "
30373038
"conform to 'Differentiable'", (Type))
30383039

30393040
// differentiation `wrt` parameters clause

lib/AST/Attr.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
926926
Printer << '(';
927927
auto *attr = cast<DifferentiatingAttr>(this);
928928
auto *derivative = cast<AbstractFunctionDecl>(D);
929-
Printer << attr->getOriginal().Name;
929+
Printer << attr->getOriginalFunctionName().Name;
930930
auto diffParamsString = getDifferentiationParametersClauseString(
931931
derivative, attr->getParameterIndices(), attr->getParsedParameters());
932932
if (!diffParamsString.empty())
@@ -941,7 +941,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
941941
Printer << '(';
942942
auto *attr = cast<TransposingAttr>(this);
943943
auto *transpose = cast<AbstractFunctionDecl>(D);
944-
Printer << attr->getOriginal().Name;
944+
Printer << attr->getOriginalFunctionName().Name;
945945
auto transParamsString = getTransposedParametersClauseString(
946946
transpose, attr->getParameterIndices(), attr->getParsedParameters());
947947
if (!transParamsString.empty())
@@ -1570,19 +1570,21 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
15701570
// SWIFT_ENABLE_TENSORFLOW
15711571
DifferentiatingAttr::DifferentiatingAttr(
15721572
bool implicit, SourceLoc atLoc, SourceRange baseRange,
1573-
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
1573+
DeclNameWithLoc originalName, ArrayRef<ParsedAutoDiffParameter> params)
15741574
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1575-
Original(std::move(original)), NumParsedParameters(params.size()) {
1575+
OriginalFunctionName(std::move(originalName)),
1576+
NumParsedParameters(params.size()) {
15761577
std::copy(params.begin(), params.end(),
15771578
getTrailingObjects<ParsedAutoDiffParameter>());
15781579
}
15791580

15801581
DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
15811582
SourceRange baseRange,
1582-
DeclNameWithLoc original,
1583+
DeclNameWithLoc originalName,
15831584
IndexSubset *indices)
15841585
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1585-
Original(std::move(original)), ParameterIndices(indices) {}
1586+
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
1587+
}
15861588

15871589
DifferentiatingAttr *
15881590
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
@@ -1607,20 +1609,21 @@ DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,
16071609

16081610
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
16091611
SourceRange baseRange, TypeRepr *baseType,
1610-
DeclNameWithLoc original,
1612+
DeclNameWithLoc originalName,
16111613
ArrayRef<ParsedAutoDiffParameter> params)
16121614
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1613-
BaseType(baseType), Original(std::move(original)),
1615+
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16141616
NumParsedParameters(params.size()) {
16151617
std::uninitialized_copy(params.begin(), params.end(),
16161618
getTrailingObjects<ParsedAutoDiffParameter>());
16171619
}
16181620

16191621
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
16201622
SourceRange baseRange, TypeRepr *baseType,
1621-
DeclNameWithLoc original, IndexSubset *indices)
1623+
DeclNameWithLoc originalName,
1624+
IndexSubset *indices)
16221625
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1623-
BaseType(baseType), Original(std::move(original)),
1626+
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16241627
ParameterIndices(indices) {}
16251628

16261629
TransposingAttr *

lib/Parse/ParseDecl.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10691069
SyntaxParsingContext FuncDeclNameContext(
10701070
SyntaxContext, SyntaxKind::FunctionDeclName);
10711071
Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID,
1072-
{ label });
1072+
{label});
10731073
result.Name =
10741074
parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc,
10751075
funcDiag, /*allowOperators=*/true,
@@ -1165,11 +1165,14 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
11651165
// Parse the name of the function.
11661166
SyntaxParsingContext FuncDeclNameContext(
11671167
SyntaxContext, SyntaxKind::FunctionDeclName);
1168+
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to
1169+
// enable, e.g. `@differentiating(init)` and
1170+
// `@differentiating(subscript)`.
11681171
original.Name = parseUnqualifiedDeclName(
1169-
/*afterDot*/ false, original.Loc,
1172+
/*afterDot*/ true, original.Loc,
11701173
diag::attr_differentiating_expected_original_name,
1171-
/*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true);
1172-
1174+
/*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true,
1175+
/*allowDeinitAndSubscript*/ true);
11731176
if (consumeIfTrailingComma())
11741177
return makeParserError();
11751178
}
@@ -1228,19 +1231,13 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
12281231
if (parseBaseTypeForQualifiedDeclName(P, baseType))
12291232
return true;
12301233

1231-
// If base type was parsed and has at least one component, then there was a
1232-
// dot before the current token.
1233-
bool afterDot = false;
1234-
if (baseType) {
1235-
if (auto ident = dyn_cast<IdentTypeRepr>(baseType)) {
1236-
auto components = ident->getComponentRange();
1237-
afterDot = std::distance(components.begin(), components.end()) > 0;
1238-
}
1239-
}
1234+
// NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to enable
1235+
// initializer and subscript lookup.
12401236
original.Name =
1241-
P.parseUnqualifiedDeclName(afterDot, original.Loc, nameParseError,
1242-
/*allowOperators*/ true,
1243-
/*allowZeroArgCompoundNames*/ true);
1237+
P.parseUnqualifiedDeclName(/*afterDot*/ true, original.Loc,
1238+
nameParseError, /*allowOperators*/ true,
1239+
/*allowZeroArgCompoundNames*/ true,
1240+
/*allowDeinitAndSubscript*/ true);
12441241

12451242
// The base type is optional, but the final unqualified decl name is not.
12461243
// If name could not be parsed, return true for error.
@@ -1285,7 +1282,6 @@ ParserResult<TransposingAttr> Parser::parseTransposingAttribute(SourceLoc atLoc,
12851282
diag::attr_transposing_expected_original_name,
12861283
baseType, original))
12871284
return makeParserError();
1288-
12891285
if (consumeIfTrailingComma())
12901286
return makeParserError();
12911287
}

0 commit comments

Comments
 (0)