Skip to content

[AutoDiff] Attribute gardening. #29050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ class DifferentiableAttr final

/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameRefWithLoc> JVP;
Expand All @@ -1662,7 +1662,7 @@ class DifferentiableAttr final
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
Expand Down Expand Up @@ -1720,7 +1720,7 @@ class DifferentiableAttr final
ParameterIndices = parameterIndices;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1771,15 +1771,15 @@ class DifferentiableAttr final
///
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
/// specifying the parameters that are differentiated "with respect to", i.e.
/// the differentiation parameters. The differentiation parameters must conform
/// to the `Differentiable` protocol.
/// the differentiability parameters. The differentiability parameters must
/// conform to the `Differentiable` protocol.
///
/// If the `wrt:` clause is unspecified, the differentiation parameters are
/// If the `wrt:` clause is unspecified, the differentiability parameters are
/// inferred to be all parameters that conform to `Differentiable`.
///
/// `@derivative(of:)` attribute type-checking verifies that the type of the
/// derivative function declaration is consistent with the type of the
/// referenced original declaration and the differentiation parameters.
/// referenced original declaration and the differentiability parameters.
///
/// Examples:
/// @derivative(of: sin(_:))
Expand All @@ -1798,9 +1798,9 @@ class DerivativeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
/// The differentiability parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The derivative function kind (JVP or VJP), resolved by the type checker.
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
Expand Down Expand Up @@ -1843,7 +1843,7 @@ class DerivativeAttr final
}
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }

/// The parsed differentiation parameters, i.e. the list of parameters
/// The parsed differentiability parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down Expand Up @@ -1872,7 +1872,7 @@ class DerivativeAttr final
/// computed property declaration.
///
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
/// parameters that are transposed "with respect to", i.e. the transposed
/// parameters that are transposed "with respect to", i.e. the linearity
/// parameters.
///
/// Examples:
Expand All @@ -1892,9 +1892,9 @@ class TransposeAttr final
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
/// The number of parsed linearity parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The transposed parameters' indices, resolved by the type checker.
/// The linearity parameter indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
Expand Down Expand Up @@ -1927,7 +1927,7 @@ class TransposeAttr final
OriginalFunction = decl;
}

/// The parsed transposed parameters, i.e. the list of parameters specified in
/// The parsed linearity parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
Expand Down
13 changes: 8 additions & 5 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1542,21 +1542,24 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
"expected a member name as second parameter in '_implements' attribute", ())

// differentiable
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
"expected a %0 function name", (StringRef))
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@differentiable' attribute", (StringRef))
ERROR(attr_differentiable_expected_label,none,
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
"or 'vjp:'", ())
ERROR(differentiable_attribute_expected_rparen,none,
"expected ')' in '@differentiable' attribute", ())
ERROR(unexpected_argument_differentiable,none,
ERROR(attr_differentiable_unexpected_argument,none,
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none,
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
"deprecated; use '@derivative' attribute for derivative registration "
"instead", ())

// differentiation `wrt` parameters clause
ERROR(expected_colon_after_label,PointsToFirstBadToken,
Expand Down
11 changes: 6 additions & 5 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1002,12 +1002,13 @@ class Parser {
Optional<DeclNameRefWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable` and `@derivative` attributes.
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
///
/// If `allowNamedParameters` is false, allow only index parameters and
/// 'self'.
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
/// 'self'. Used for `@transpose` attributes.
bool parseDifferentiabilityParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &parameters, StringRef attrName,
bool allowNamedParameters = true);

/// Parse the @derivative attribute.
Expand Down
55 changes: 32 additions & 23 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,51 +366,60 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
Printer.printNewline();
}

/// Printing style for a differentiation parameter in a `wrt:` differentiation
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
/// `@transpose` attributes.
enum class DifferentiationParameterPrintingStyle {
/// Print parameter by name.
/// The kind of a parameter in a `wrt:` differentiation parameters clause:
/// either a differentiability parameter or a linearity parameter. Used for
/// printing `@differentiable`, `@derivative`, and `@transpose` attributes.
enum class DifferentiationParameterKind {
/// A differentiability parameter, printed by name.
/// Used for `@differentiable` and `@derivative` attribute.
Name,
/// Print parameter by index.
Differentiability,
/// A linearity parameter, printed by index.
/// Used for `@transpose` attribute.
Index
Linearity
};

/// Returns the differentiation parameters clause string for the given function,
/// parameter indices, parsed parameters, . Use the parameter indices if
/// specified; otherwise, use the parsed parameters.
/// parameter indices, parsed parameters, and differentiation parameter kind.
/// Use the parameter indices if specified; otherwise, use the parsed
/// parameters.
static std::string getDifferentiationParametersClauseString(
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
const AbstractFunctionDecl *function, IndexSubset *parameterIndices,
ArrayRef<ParsedAutoDiffParameter> parsedParams,
DifferentiationParameterPrintingStyle style) {
DifferentiationParameterKind parameterKind) {
assert(function);
bool isInstanceMethod = function->isInstanceMember();
bool isStaticMethod = function->isStatic();
std::string result;
llvm::raw_string_ostream printer(result);

// Use the parameter indices, if specified.
if (paramIndices) {
auto parameters = paramIndices->getBitVector();
if (parameterIndices) {
auto parameters = parameterIndices->getBitVector();
auto parameterCount = parameters.count();
printer << "wrt: ";
if (parameterCount > 1)
printer << '(';
// Check if differentiating wrt `self`. If so, manually print it first.
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
bool isWrtSelf =
(isInstanceMethod ||
(isStaticMethod &&
parameterKind == DifferentiationParameterKind::Linearity)) &&
parameters.test(parameters.size() - 1);
if (isWrtSelf) {
parameters.reset(parameters.size() - 1);
printer << "self";
if (parameters.any())
printer << ", ";
}
// Print remaining differentiation parameters.
interleave(parameters.set_bits(), [&](unsigned index) {
switch (style) {
case DifferentiationParameterPrintingStyle::Name:
switch (parameterKind) {
// Print differentiability parameters by name.
case DifferentiationParameterKind::Differentiability:
printer << function->getParameters()->get(index)->getName().str();
break;
case DifferentiationParameterPrintingStyle::Index:
// Print linearity parameters by index.
case DifferentiationParameterKind::Linearity:
printer << index;
break;
}
Expand Down Expand Up @@ -487,7 +496,7 @@ static void printDifferentiableAttrArguments(
if (!omitWrtClause) {
auto diffParamsString = getDifferentiationParametersClauseString(
original, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
DifferentiationParameterKind::Differentiability);
// Check whether differentiation parameter clause is empty.
// Handles edge case where resolved parameter indices are unset and
// parsed parameters are empty. This case should never trigger for
Expand Down Expand Up @@ -927,7 +936,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
auto *derivative = cast<AbstractFunctionDecl>(D);
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
DifferentiationParameterKind::Differentiability);
if (!diffParamsString.empty())
Printer << ", " << diffParamsString;
Printer << ')';
Expand All @@ -942,7 +951,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
auto *transpose = cast<AbstractFunctionDecl>(D);
auto transParamsString = getDifferentiationParametersClauseString(
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Index);
DifferentiationParameterKind::Linearity);
if (!transParamsString.empty())
Printer << ", " << transParamsString;
Printer << ')';
Expand Down Expand Up @@ -1510,11 +1519,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(

void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause,
bool omitAssociatedFunctions) const {
bool omitDerivativeFunctions) const {
StreamPrinter P(OS);
P << "@" << getAttrName();
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause,
omitAssociatedFunctions);
omitDerivativeFunctions);
}

DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
Expand Down
Loading