Skip to content

Commit c9c51be

Browse files
authored
[AutoDiff] Attribute gardening. (#29050)
Upstream changes from `tensorflow` branch: - #28932: deprecate `@differentiable(jvp:vjp)` arguments. - #29038: gardening. Additional gardening included.
1 parent 89c5c0e commit c9c51be

File tree

8 files changed

+275
-214
lines changed

8 files changed

+275
-214
lines changed

include/swift/AST/Attr.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,7 +1650,7 @@ class DifferentiableAttr final
16501650

16511651
/// Whether this function is linear (optional).
16521652
bool Linear;
1653-
/// The number of parsed parameters specified in 'wrt:'.
1653+
/// The number of parsed differentiability parameters specified in 'wrt:'.
16541654
unsigned NumParsedParameters = 0;
16551655
/// The JVP function.
16561656
Optional<DeclNameRefWithLoc> JVP;
@@ -1662,7 +1662,7 @@ class DifferentiableAttr final
16621662
/// The VJP function (optional), resolved by the type checker if VJP name is
16631663
/// specified.
16641664
FuncDecl *VJPFunction = nullptr;
1665-
/// The differentiation parameters' indices, resolved by the type checker.
1665+
/// The differentiability parameter indices, resolved by the type checker.
16661666
IndexSubset *ParameterIndices = nullptr;
16671667
/// The trailing where clause (optional).
16681668
TrailingWhereClause *WhereClause = nullptr;
@@ -1720,7 +1720,7 @@ class DifferentiableAttr final
17201720
ParameterIndices = parameterIndices;
17211721
}
17221722

1723-
/// The parsed differentiation parameters, i.e. the list of parameters
1723+
/// The parsed differentiability parameters, i.e. the list of parameters
17241724
/// specified in 'wrt:'.
17251725
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
17261726
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
@@ -1771,15 +1771,15 @@ class DifferentiableAttr final
17711771
///
17721772
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
17731773
/// specifying the parameters that are differentiated "with respect to", i.e.
1774-
/// the differentiation parameters. The differentiation parameters must conform
1775-
/// to the `Differentiable` protocol.
1774+
/// the differentiability parameters. The differentiability parameters must
1775+
/// conform to the `Differentiable` protocol.
17761776
///
1777-
/// If the `wrt:` clause is unspecified, the differentiation parameters are
1777+
/// If the `wrt:` clause is unspecified, the differentiability parameters are
17781778
/// inferred to be all parameters that conform to `Differentiable`.
17791779
///
17801780
/// `@derivative(of:)` attribute type-checking verifies that the type of the
17811781
/// derivative function declaration is consistent with the type of the
1782-
/// referenced original declaration and the differentiation parameters.
1782+
/// referenced original declaration and the differentiability parameters.
17831783
///
17841784
/// Examples:
17851785
/// @derivative(of: sin(_:))
@@ -1798,9 +1798,9 @@ class DerivativeAttr final
17981798
DeclNameRefWithLoc OriginalFunctionName;
17991799
/// The original function declaration, resolved by the type checker.
18001800
AbstractFunctionDecl *OriginalFunction = nullptr;
1801-
/// The number of parsed parameters specified in 'wrt:'.
1801+
/// The number of parsed differentiability parameters specified in 'wrt:'.
18021802
unsigned NumParsedParameters = 0;
1803-
/// The differentiation parameters' indices, resolved by the type checker.
1803+
/// The differentiability parameter indices, resolved by the type checker.
18041804
IndexSubset *ParameterIndices = nullptr;
18051805
/// The derivative function kind (JVP or VJP), resolved by the type checker.
18061806
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
@@ -1843,7 +1843,7 @@ class DerivativeAttr final
18431843
}
18441844
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
18451845

1846-
/// The parsed differentiation parameters, i.e. the list of parameters
1846+
/// The parsed differentiability parameters, i.e. the list of parameters
18471847
/// specified in 'wrt:'.
18481848
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
18491849
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
@@ -1872,7 +1872,7 @@ class DerivativeAttr final
18721872
/// computed property declaration.
18731873
///
18741874
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
1875-
/// parameters that are transposed "with respect to", i.e. the transposed
1875+
/// parameters that are transposed "with respect to", i.e. the linearity
18761876
/// parameters.
18771877
///
18781878
/// Examples:
@@ -1892,9 +1892,9 @@ class TransposeAttr final
18921892
DeclNameRefWithLoc OriginalFunctionName;
18931893
/// The original function declaration, resolved by the type checker.
18941894
AbstractFunctionDecl *OriginalFunction = nullptr;
1895-
/// The number of parsed parameters specified in 'wrt:'.
1895+
/// The number of parsed linearity parameters specified in 'wrt:'.
18961896
unsigned NumParsedParameters = 0;
1897-
/// The transposed parameters' indices, resolved by the type checker.
1897+
/// The linearity parameter indices, resolved by the type checker.
18981898
IndexSubset *ParameterIndices = nullptr;
18991899

19001900
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
@@ -1927,7 +1927,7 @@ class TransposeAttr final
19271927
OriginalFunction = decl;
19281928
}
19291929

1930-
/// The parsed transposed parameters, i.e. the list of parameters specified in
1930+
/// The parsed linearity parameters, i.e. the list of parameters specified in
19311931
/// 'wrt:'.
19321932
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
19331933
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};

include/swift/AST/DiagnosticsParse.def

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,21 +1542,24 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
15421542
"expected a member name as second parameter in '_implements' attribute", ())
15431543

15441544
// differentiable
1545+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
15451546
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
15461547
"expected a %0 function name", (StringRef))
15471548
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15481549
"expected a list of parameters to differentiate with respect to", ())
1550+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
15491551
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
15501552
"use 'wrt:' to specify parameters to differentiate with respect to", ())
1551-
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
1552-
"missing label '%0:' in '@differentiable' attribute", (StringRef))
15531553
ERROR(attr_differentiable_expected_label,none,
15541554
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
15551555
"or 'vjp:'", ())
1556-
ERROR(differentiable_attribute_expected_rparen,none,
1557-
"expected ')' in '@differentiable' attribute", ())
1558-
ERROR(unexpected_argument_differentiable,none,
1556+
ERROR(attr_differentiable_unexpected_argument,none,
15591557
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
1558+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
1559+
WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none,
1560+
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
1561+
"deprecated; use '@derivative' attribute for derivative registration "
1562+
"instead", ())
15601563

15611564
// differentiation `wrt` parameters clause
15621565
ERROR(expected_colon_after_label,PointsToFirstBadToken,

include/swift/Parse/Parser.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,12 +1002,13 @@ class Parser {
10021002
Optional<DeclNameRefWithLoc> &vjpSpec,
10031003
TrailingWhereClause *&whereClause);
10041004

1005-
/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
1006-
/// `@differentiable` and `@derivative` attributes.
1005+
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
1006+
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
1007+
///
10071008
/// If `allowNamedParameters` is false, allow only index parameters and
1008-
/// 'self'.
1009-
bool parseDifferentiationParametersClause(
1010-
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
1009+
/// 'self'. Used for `@transpose` attributes.
1010+
bool parseDifferentiabilityParametersClause(
1011+
SmallVectorImpl<ParsedAutoDiffParameter> &parameters, StringRef attrName,
10111012
bool allowNamedParameters = true);
10121013

10131014
/// Parse the @derivative attribute.

lib/AST/Attr.cpp

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -366,51 +366,60 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
366366
Printer.printNewline();
367367
}
368368

369-
/// Printing style for a differentiation parameter in a `wrt:` differentiation
370-
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
371-
/// `@transpose` attributes.
372-
enum class DifferentiationParameterPrintingStyle {
373-
/// Print parameter by name.
369+
/// The kind of a parameter in a `wrt:` differentiation parameters clause:
370+
/// either a differentiability parameter or a linearity parameter. Used for
371+
/// printing `@differentiable`, `@derivative`, and `@transpose` attributes.
372+
enum class DifferentiationParameterKind {
373+
/// A differentiability parameter, printed by name.
374374
/// Used for `@differentiable` and `@derivative` attribute.
375-
Name,
376-
/// Print parameter by index.
375+
Differentiability,
376+
/// A linearity parameter, printed by index.
377377
/// Used for `@transpose` attribute.
378-
Index
378+
Linearity
379379
};
380380

381381
/// Returns the differentiation parameters clause string for the given function,
382-
/// parameter indices, parsed parameters, . Use the parameter indices if
383-
/// specified; otherwise, use the parsed parameters.
382+
/// parameter indices, parsed parameters, and differentiation parameter kind.
383+
/// Use the parameter indices if specified; otherwise, use the parsed
384+
/// parameters.
384385
static std::string getDifferentiationParametersClauseString(
385-
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
386+
const AbstractFunctionDecl *function, IndexSubset *parameterIndices,
386387
ArrayRef<ParsedAutoDiffParameter> parsedParams,
387-
DifferentiationParameterPrintingStyle style) {
388+
DifferentiationParameterKind parameterKind) {
388389
assert(function);
389390
bool isInstanceMethod = function->isInstanceMember();
391+
bool isStaticMethod = function->isStatic();
390392
std::string result;
391393
llvm::raw_string_ostream printer(result);
392394

393395
// Use the parameter indices, if specified.
394-
if (paramIndices) {
395-
auto parameters = paramIndices->getBitVector();
396+
if (parameterIndices) {
397+
auto parameters = parameterIndices->getBitVector();
396398
auto parameterCount = parameters.count();
397399
printer << "wrt: ";
398400
if (parameterCount > 1)
399401
printer << '(';
400402
// Check if differentiating wrt `self`. If so, manually print it first.
401-
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
403+
bool isWrtSelf =
404+
(isInstanceMethod ||
405+
(isStaticMethod &&
406+
parameterKind == DifferentiationParameterKind::Linearity)) &&
407+
parameters.test(parameters.size() - 1);
408+
if (isWrtSelf) {
402409
parameters.reset(parameters.size() - 1);
403410
printer << "self";
404411
if (parameters.any())
405412
printer << ", ";
406413
}
407414
// Print remaining differentiation parameters.
408415
interleave(parameters.set_bits(), [&](unsigned index) {
409-
switch (style) {
410-
case DifferentiationParameterPrintingStyle::Name:
416+
switch (parameterKind) {
417+
// Print differentiability parameters by name.
418+
case DifferentiationParameterKind::Differentiability:
411419
printer << function->getParameters()->get(index)->getName().str();
412420
break;
413-
case DifferentiationParameterPrintingStyle::Index:
421+
// Print linearity parameters by index.
422+
case DifferentiationParameterKind::Linearity:
414423
printer << index;
415424
break;
416425
}
@@ -487,7 +496,7 @@ static void printDifferentiableAttrArguments(
487496
if (!omitWrtClause) {
488497
auto diffParamsString = getDifferentiationParametersClauseString(
489498
original, attr->getParameterIndices(), attr->getParsedParameters(),
490-
DifferentiationParameterPrintingStyle::Name);
499+
DifferentiationParameterKind::Differentiability);
491500
// Check whether differentiation parameter clause is empty.
492501
// Handles edge case where resolved parameter indices are unset and
493502
// parsed parameters are empty. This case should never trigger for
@@ -927,7 +936,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
927936
auto *derivative = cast<AbstractFunctionDecl>(D);
928937
auto diffParamsString = getDifferentiationParametersClauseString(
929938
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
930-
DifferentiationParameterPrintingStyle::Name);
939+
DifferentiationParameterKind::Differentiability);
931940
if (!diffParamsString.empty())
932941
Printer << ", " << diffParamsString;
933942
Printer << ')';
@@ -942,7 +951,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
942951
auto *transpose = cast<AbstractFunctionDecl>(D);
943952
auto transParamsString = getDifferentiationParametersClauseString(
944953
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
945-
DifferentiationParameterPrintingStyle::Index);
954+
DifferentiationParameterKind::Linearity);
946955
if (!transParamsString.empty())
947956
Printer << ", " << transParamsString;
948957
Printer << ')';
@@ -1510,11 +1519,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
15101519

15111520
void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
15121521
bool omitWrtClause,
1513-
bool omitAssociatedFunctions) const {
1522+
bool omitDerivativeFunctions) const {
15141523
StreamPrinter P(OS);
15151524
P << "@" << getAttrName();
15161525
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause,
1517-
omitAssociatedFunctions);
1526+
omitDerivativeFunctions);
15181527
}
15191528

15201529
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,

0 commit comments

Comments
 (0)