Skip to content

[AutoDiff upstream] Add @transpose attribute. #27545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ DECL_ATTR(_implicitly_synthesizes_nested_requirement, ImplicitlySynthesizesNeste
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
98)

DECL_ATTR(transpose, Transpose,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
99)

#undef TYPE_ATTR
#undef DECL_ATTR_ALIAS
#undef CONTEXTUAL_DECL_ATTR_ALIAS
Expand Down
99 changes: 91 additions & 8 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1765,19 +1765,19 @@ class DifferentiableAttr final
}
};

/// The `@derivative` attribute registers a function as a derivative of another
/// function-like declaration: a 'func', 'init', 'subscript', or 'var' computed
/// property declaration.
/// The `@derivative(of:)` attribute registers a function as a derivative of
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
/// computed property declaration.
///
/// The `@derivative` attribute also has an optional `wrt:` clause specifying
/// the parameters that are differentiated "with respect to", i.e. the
/// differentiation parameters. The differentiation parameters must conform to
/// the `Differentiable` protocol.
/// The `@derivative(of:)` attribute also has an optional `wrt:` clause
/// specifying the parameters that are differentiated "with respect to", i.e.
/// the differentiation parameters. The differentiation parameters must conform
/// to the `Differentiable` protocol.
///
/// If the `wrt:` clause is unspecified, the differentiation parameters are
/// inferred to be all parameters that conform to `Differentiable`.
///
/// `@derivative` attribute type-checking verifies that the type of the
/// `@derivative(of:)` attribute type-checking verifies that the type of the
/// derivative function declaration is consistent with the type of the
/// referenced original declaration and the differentiation parameters.
///
Expand Down Expand Up @@ -1858,6 +1858,89 @@ class DerivativeAttr final
}
};

/// The `@transpose(of:)` attribute registers a function as a transpose of
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
/// computed property declaration.
///
/// The `@transpose(of:)` attribute also has a `wrt:` clause specifying the
/// parameters that are transposed "with respect to", i.e. the transposed
/// parameters.
///
/// Examples:
/// @transpose(of: foo)
/// @transpose(of: +, wrt: (0, 1))
class TransposeAttr final
: public DeclAttribute,
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The base type of the original function.
/// This is non-null only when the original function is not top-level (i.e. it
/// is an instance/static method).
TypeRepr *BaseTypeRepr;
/// The original function name.
DeclNameRefWithLoc OriginalFunctionName;
/// The original function declaration, resolved by the type checker.
AbstractFunctionDecl *OriginalFunction = nullptr;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The transposed parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
IndexSubset *indices);

public:
static TransposeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static TransposeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameRefWithLoc original,
IndexSubset *indices);

TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
DeclNameRefWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
}
AbstractFunctionDecl *getOriginalFunction() const {
return OriginalFunction;
}
void setOriginalFunction(AbstractFunctionDecl *decl) {
OriginalFunction = decl;
}

/// The parsed transposed parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Transpose;
}
};

/// Attributes that may be applied to declarations.
class DeclAttributes {
/// Linked list of declaration attributes.
Expand Down
7 changes: 5 additions & 2 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1560,9 +1560,12 @@ ERROR(expected_colon_after_label,PointsToFirstBadToken,
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
"expected a parameter, which can be a function parameter name, "
"parameter index, or 'self'", ())
ERROR(diff_params_clause_expected_parameter_unnamed,PointsToFirstBadToken,
"expected a parameter, which can be a function parameter index or 'self'",
())

// derivative
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
// Automatic differentiation attributes
ERROR(autodiff_attr_expected_original_decl_name,PointsToFirstBadToken,
"expected an original function name", ())

//------------------------------------------------------------------------------
Expand Down
41 changes: 38 additions & 3 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,14 +994,22 @@ class Parser {
Optional<DeclNameRefWithLoc> &vjpSpec,
TrailingWhereClause *&whereClause);

/// Parse a differentiation parameters clause.
/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
/// `@differentiable` and `@derivative` attributes.
/// If `allowNamedParameters` is false, allow only index parameters and
/// 'self'.
bool parseDifferentiationParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
bool allowNamedParameters = true);

/// Parse the @derivative attribute.
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse the @transpose attribute.
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse a specific attribute.
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);

Expand Down Expand Up @@ -1143,7 +1151,19 @@ class Parser {
SourceLoc &LAngleLoc,
SourceLoc &RAngleLoc);

ParserResult<TypeRepr> parseTypeIdentifier();
/// Parses a type identifier (e.g. 'Foo' or 'Foo.Bar.Baz').
///
/// When `isParsingQualifiedDeclBaseType` is true:
/// - Parses and returns the base type for a qualified declaration name,
/// positioning the parser at the '.' before the final declaration name.
// This position is important for parsing final declaration names like
// '.init' via `parseUnqualifiedDeclName`.
/// - For example, 'Foo.Bar.f' parses as 'Foo.Bar' and the parser is
/// positioned at '.f'.
/// - If there is no base type qualifier (e.g. when parsing just 'f'), returns
/// an empty parser error.
ParserResult<TypeRepr> parseTypeIdentifier(
bool isParsingQualifiedDeclBaseType = false);
ParserResult<TypeRepr> parseOldStyleProtocolComposition();
ParserResult<TypeRepr> parseAnyType();
ParserResult<TypeRepr> parseSILBoxType(GenericParamList *generics,
Expand Down Expand Up @@ -1357,6 +1377,14 @@ class Parser {
bool canParseAsGenericArgumentList();

bool canParseType();

/// Returns true if a simple type identifier can be parsed.
///
/// \verbatim
/// simple-type-identifier: identifier generic-argument-list?
/// \endverbatim
bool canParseSimpleTypeIdentifier();

bool canParseTypeIdentifier();
bool canParseTypeIdentifierOrTypeComposition();
bool canParseOldStyleProtocolComposition();
Expand All @@ -1366,6 +1394,13 @@ class Parser {

bool canParseTypedPattern();

/// Returns true if a qualified declaration name base type can be parsed.
///
/// \verbatim
/// qualified-decl-name-base-type: simple-type-identifier '.'
/// \endverbatim
bool canParseBaseTypeForQualifiedDeclName();

//===--------------------------------------------------------------------===//
// Expression Parsing
ParserResult<Expr> parseExpr(Diag<> ID) {
Expand Down
101 changes: 90 additions & 11 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,25 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
Printer.printNewline();
}

// Returns the differentiation parameters clause string for the given function,
// parameter indices, and parsed parameters.
/// Printing style for a differentiation parameter in a `wrt:` differentiation
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
/// `@transpose` attributes.
enum class DifferentiationParameterPrintingStyle {
/// Print parameter by name.
/// Used for `@differentiable` and `@derivative` attribute.
Name,
/// Print parameter by index.
/// Used for `@transpose` attribute.
Index
};

/// Returns the differentiation parameters clause string for the given function,
/// parameter indices, parsed parameters, . Use the parameter indices if
/// specified; otherwise, use the parsed parameters.
static std::string getDifferentiationParametersClauseString(
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
ArrayRef<ParsedAutoDiffParameter> parsedParams,
DifferentiationParameterPrintingStyle style) {
assert(function);
bool isInstanceMethod = function->isInstanceMember();
std::string result;
Expand All @@ -392,7 +406,14 @@ static std::string getDifferentiationParametersClauseString(
}
// Print remaining differentiation parameters.
interleave(parameters.set_bits(), [&](unsigned index) {
printer << function->getParameters()->get(index)->getName().str();
switch (style) {
case DifferentiationParameterPrintingStyle::Name:
printer << function->getParameters()->get(index)->getName().str();
break;
case DifferentiationParameterPrintingStyle::Index:
printer << index;
break;
}
}, [&] { printer << ", "; });
if (parameterCount > 1)
printer << ')';
Expand Down Expand Up @@ -425,11 +446,11 @@ static std::string getDifferentiationParametersClauseString(
return printer.str();
}

// Print the arguments of the given `@differentiable` attribute.
// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
// parameters clause.
// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
// functions.
/// Print the arguments of the given `@differentiable` attribute.
/// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
/// parameters clause.
/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
/// functions.
static void printDifferentiableAttrArguments(
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
const Decl *D, bool omitWrtClause = false,
Expand Down Expand Up @@ -465,7 +486,8 @@ static void printDifferentiableAttrArguments(
// Print differentiation parameters clause, unless it is to be omitted.
if (!omitWrtClause) {
auto diffParamsString = getDifferentiationParametersClauseString(
original, attr->getParameterIndices(), attr->getParsedParameters());
original, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
// Check whether differentiation parameter clause is empty.
// Handles edge case where resolved parameter indices are unset and
// parsed parameters are empty. This case should never trigger for
Expand Down Expand Up @@ -904,13 +926,29 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer << attr->getOriginalFunctionName().Name;
auto *derivative = cast<AbstractFunctionDecl>(D);
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters());
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Name);
if (!diffParamsString.empty())
Printer << ", " << diffParamsString;
Printer << ')';
break;
}

case DAK_Transpose: {
Printer.printAttrName("@transpose");
Printer << "(of: ";
auto *attr = cast<TransposeAttr>(this);
Printer << attr->getOriginalFunctionName().Name;
auto *transpose = cast<AbstractFunctionDecl>(D);
auto transParamsString = getDifferentiationParametersClauseString(
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
DifferentiationParameterPrintingStyle::Index);
if (!transParamsString.empty())
Printer << ", " << transParamsString;
Printer << ')';
break;
}

case DAK_ImplicitlySynthesizesNestedRequirement:
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";
Expand Down Expand Up @@ -1054,6 +1092,8 @@ StringRef DeclAttribute::getAttrName() const {
return "differentiable";
case DAK_Derivative:
return "derivative";
case DAK_Transpose:
return "transpose";
}
llvm_unreachable("bad DeclAttrKind");
}
Expand Down Expand Up @@ -1511,6 +1551,45 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
std::move(originalName), indices);
}

TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::uninitialized_copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseTypeRepr,
DeclNameRefWithLoc originalName, IndexSubset *indices)
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
ParameterIndices(indices) {}

TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType,
DeclNameRefWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(TransposeAttr));
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
std::move(originalName), params);
}

TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType,
DeclNameRefWithLoc originalName,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
std::move(originalName), indices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
TypeLoc ProtocolType,
DeclName MemberName,
Expand Down
Loading