Skip to content

Commit c842fee

Browse files
committed
[AutoDiff upstream] Add @transpose(of:) attribute.
The `@transpose(of:)` attribute registers a function as a transpose of another function. This patch adds the `@transpose(of:)` attribute definition, syntax, parsing, and printing. Resolves TF-827. Todos: - Type-checking (TF-830, TF-1060). - Enable serialization (TF-838). - Use module-qualified names instead of custom qualified name syntax/parsing (TF-1066).
1 parent fb76496 commit c842fee

18 files changed

+767
-66
lines changed

include/swift/AST/Attr.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,11 @@ DECL_ATTR(_implicitly_synthesizes_nested_requirement, ImplicitlySynthesizesNeste
540540
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
541541
98)
542542

543+
DECL_ATTR(transpose, Transpose,
544+
OnFunc | LongAttribute | AllowMultipleAttributes |
545+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
546+
99)
547+
543548
#undef TYPE_ATTR
544549
#undef DECL_ATTR_ALIAS
545550
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/Attr.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,6 +1858,89 @@ class DerivativeAttr final
18581858
}
18591859
};
18601860

1861+
/// The `@transpose` attribute registers a function as a transpose of another
1862+
/// function-like declaration: a 'func', 'init', 'subscript', or 'var' computed
1863+
/// property declaration.
1864+
///
1865+
/// The `@transpose` attribute also has a `wrt:` clause specifying the
1866+
/// parameters that are transposed "with respect to", i.e. the transposed
1867+
/// parameters.
1868+
///
1869+
/// Examples:
1870+
/// @transpose(of: foo)
1871+
/// @transpose(of: +, wrt: (lhs, rhs))
1872+
class TransposeAttr final
1873+
: public DeclAttribute,
1874+
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
1875+
friend TrailingObjects;
1876+
1877+
/// The base type of the original function.
1878+
/// This is non-null only when the original function is not top-level (i.e. it
1879+
/// is an instance/static method).
1880+
TypeRepr *BaseTypeRepr;
1881+
/// The original function name.
1882+
DeclNameRefWithLoc OriginalFunctionName;
1883+
/// The original function declaration, resolved by the type checker.
1884+
AbstractFunctionDecl *OriginalFunction = nullptr;
1885+
/// The number of parsed parameters specified in 'wrt:'.
1886+
unsigned NumParsedParameters = 0;
1887+
/// The transposed parameters' indices, resolved by the type checker.
1888+
IndexSubset *ParameterIndices = nullptr;
1889+
1890+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1891+
TypeRepr *baseType, DeclNameRefWithLoc original,
1892+
ArrayRef<ParsedAutoDiffParameter> params);
1893+
1894+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1895+
TypeRepr *baseType, DeclNameRefWithLoc original,
1896+
IndexSubset *indices);
1897+
1898+
public:
1899+
static TransposeAttr *create(ASTContext &context, bool implicit,
1900+
SourceLoc atLoc, SourceRange baseRange,
1901+
TypeRepr *baseType, DeclNameRefWithLoc original,
1902+
ArrayRef<ParsedAutoDiffParameter> params);
1903+
1904+
static TransposeAttr *create(ASTContext &context, bool implicit,
1905+
SourceLoc atLoc, SourceRange baseRange,
1906+
TypeRepr *baseType, DeclNameRefWithLoc original,
1907+
IndexSubset *indices);
1908+
1909+
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
1910+
DeclNameRefWithLoc getOriginalFunctionName() const {
1911+
return OriginalFunctionName;
1912+
}
1913+
AbstractFunctionDecl *getOriginalFunction() const {
1914+
return OriginalFunction;
1915+
}
1916+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1917+
OriginalFunction = decl;
1918+
}
1919+
1920+
/// The parsed transposed parameters, i.e. the list of parameters specified in
1921+
/// 'wrt:'.
1922+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1923+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1924+
}
1925+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1926+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1927+
}
1928+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1929+
return NumParsedParameters;
1930+
}
1931+
1932+
IndexSubset *getParameterIndices() const {
1933+
return ParameterIndices;
1934+
}
1935+
void setParameterIndices(IndexSubset *parameterIndices) {
1936+
ParameterIndices = parameterIndices;
1937+
}
1938+
1939+
static bool classof(const DeclAttribute *DA) {
1940+
return DA->getKind() == DAK_Transpose;
1941+
}
1942+
};
1943+
18611944
/// Attributes that may be applied to declarations.
18621945
class DeclAttributes {
18631946
/// Linked list of declaration attributes.

include/swift/AST/DiagnosticsParse.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,9 +1560,12 @@ ERROR(expected_colon_after_label,PointsToFirstBadToken,
15601560
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15611561
"expected a parameter, which can be a function parameter name, "
15621562
"parameter index, or 'self'", ())
1563+
ERROR(diff_params_clause_expected_parameter_unnamed,PointsToFirstBadToken,
1564+
"expected a parameter, which can be a function parameter index or 'self'",
1565+
())
15631566

1564-
// derivative
1565-
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
1567+
// Automatic differentiation attributes
1568+
ERROR(autodiff_attr_expected_original_decl_name,PointsToFirstBadToken,
15661569
"expected an original function name", ())
15671570

15681571
//------------------------------------------------------------------------------

include/swift/Parse/Parser.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -994,14 +994,22 @@ class Parser {
994994
Optional<DeclNameRefWithLoc> &vjpSpec,
995995
TrailingWhereClause *&whereClause);
996996

997-
/// Parse a differentiation parameters clause.
997+
/// Parse a differentiation parameters clause, i.e. the 'wrt:' clause in
998+
/// `@differentiable` and `@derivative` attributes.
999+
/// If `allowNamedParameters` is false, allow only index parameters and
1000+
/// 'self'.
9981001
bool parseDifferentiationParametersClause(
999-
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1002+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName,
1003+
bool allowNamedParameters = true);
10001004

10011005
/// Parse the @derivative attribute.
10021006
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
10031007
SourceLoc Loc);
10041008

1009+
/// Parse the @transpose attribute.
1010+
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
1011+
SourceLoc Loc);
1012+
10051013
/// Parse a specific attribute.
10061014
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);
10071015

@@ -1143,7 +1151,19 @@ class Parser {
11431151
SourceLoc &LAngleLoc,
11441152
SourceLoc &RAngleLoc);
11451153

1146-
ParserResult<TypeRepr> parseTypeIdentifier();
1154+
/// Parses a type identifier (e.g. 'Foo' or 'Foo.Bar.Baz').
1155+
///
1156+
/// When `isParsingQualifiedDeclBaseType` is true:
1157+
/// - Parses and returns the base type for a qualified declaration name,
1158+
/// positioning the parser at the '.' before the final declaration name.
1159+
// This position is important for parsing final declaration names like
1160+
// '.init' via `parseUnqualifiedDeclName`.
1161+
/// - For example, 'Foo.Bar.f' parses as 'Foo.Bar' and the parser is
1162+
/// positioned at '.f'.
1163+
/// - If there is no base type qualifier (e.g. when parsing just 'f'), returns
1164+
/// an empty parser error.
1165+
ParserResult<TypeRepr> parseTypeIdentifier(
1166+
bool isParsingQualifiedDeclBaseType = false);
11471167
ParserResult<TypeRepr> parseOldStyleProtocolComposition();
11481168
ParserResult<TypeRepr> parseAnyType();
11491169
ParserResult<TypeRepr> parseSILBoxType(GenericParamList *generics,
@@ -1366,6 +1386,14 @@ class Parser {
13661386

13671387
bool canParseTypedPattern();
13681388

1389+
/// Returns true if a base type for a qualified declaration name can be
1390+
/// parsed.
1391+
/// Examples:
1392+
/// 'Foo.f' -> true
1393+
/// 'Foo.Bar.f' -> true
1394+
/// 'f' -> false
1395+
bool canParseBaseTypeForQualifiedDeclName();
1396+
13691397
//===--------------------------------------------------------------------===//
13701398
// Expression Parsing
13711399
ParserResult<Expr> parseExpr(Diag<> ID) {

lib/AST/Attr.cpp

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,25 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
366366
Printer.printNewline();
367367
}
368368

369-
// Returns the differentiation parameters clause string for the given function,
370-
// parameter indices, and parsed parameters.
369+
/// Printing style for a differentiation parameter in a `wrt:` differentiation
370+
/// parameters clause. Used for printing `@differentiable`, `@derivative`, and
371+
/// `@transpose` attributes.
372+
enum class DifferentiationParameterPrintingStyle {
373+
/// Print parameter by name.
374+
/// Used for `@differentiable` and `@derivative` attribute.
375+
Name,
376+
/// Print parameter by index.
377+
/// Used for `@transpose` attribute.
378+
Index
379+
};
380+
381+
/// Returns the differentiation parameters clause string for the given function,
382+
/// parameter indices, parsed parameters, . Use the parameter indices if
383+
/// specified; otherwise, use the parsed parameters.
371384
static std::string getDifferentiationParametersClauseString(
372385
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
373-
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
386+
ArrayRef<ParsedAutoDiffParameter> parsedParams,
387+
DifferentiationParameterPrintingStyle style) {
374388
assert(function);
375389
bool isInstanceMethod = function->isInstanceMember();
376390
std::string result;
@@ -392,7 +406,14 @@ static std::string getDifferentiationParametersClauseString(
392406
}
393407
// Print remaining differentiation parameters.
394408
interleave(parameters.set_bits(), [&](unsigned index) {
395-
printer << function->getParameters()->get(index)->getName().str();
409+
switch (style) {
410+
case DifferentiationParameterPrintingStyle::Name:
411+
printer << function->getParameters()->get(index)->getName().str();
412+
break;
413+
case DifferentiationParameterPrintingStyle::Index:
414+
printer << index;
415+
break;
416+
}
396417
}, [&] { printer << ", "; });
397418
if (parameterCount > 1)
398419
printer << ')';
@@ -425,11 +446,11 @@ static std::string getDifferentiationParametersClauseString(
425446
return printer.str();
426447
}
427448

428-
// Print the arguments of the given `@differentiable` attribute.
429-
// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
430-
// parameters clause.
431-
// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
432-
// functions.
449+
/// Print the arguments of the given `@differentiable` attribute.
450+
/// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
451+
/// parameters clause.
452+
/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
453+
/// functions.
433454
static void printDifferentiableAttrArguments(
434455
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
435456
const Decl *D, bool omitWrtClause = false,
@@ -465,7 +486,8 @@ static void printDifferentiableAttrArguments(
465486
// Print differentiation parameters clause, unless it is to be omitted.
466487
if (!omitWrtClause) {
467488
auto diffParamsString = getDifferentiationParametersClauseString(
468-
original, attr->getParameterIndices(), attr->getParsedParameters());
489+
original, attr->getParameterIndices(), attr->getParsedParameters(),
490+
DifferentiationParameterPrintingStyle::Name);
469491
// Check whether differentiation parameter clause is empty.
470492
// Handles edge case where resolved parameter indices are unset and
471493
// parsed parameters are empty. This case should never trigger for
@@ -897,6 +919,21 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
897919
break;
898920
}
899921

922+
case DAK_Transpose: {
923+
Printer.printAttrName("@transpose");
924+
Printer << "(of: ";
925+
auto *attr = cast<TransposeAttr>(this);
926+
Printer << attr->getOriginalFunctionName().Name;
927+
auto *transpose = cast<AbstractFunctionDecl>(D);
928+
auto transParamsString = getDifferentiationParametersClauseString(
929+
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
930+
DifferentiationParameterPrintingStyle::Index);
931+
if (!transParamsString.empty())
932+
Printer << ", " << transParamsString;
933+
Printer << ')';
934+
break;
935+
}
936+
900937
case DAK_ImplicitlySynthesizesNestedRequirement:
901938
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
902939
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";
@@ -1040,6 +1077,8 @@ StringRef DeclAttribute::getAttrName() const {
10401077
return "differentiable";
10411078
case DAK_Derivative:
10421079
return "derivative";
1080+
case DAK_Transpose:
1081+
return "transpose";
10431082
}
10441083
llvm_unreachable("bad DeclAttrKind");
10451084
}
@@ -1497,6 +1536,45 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
14971536
std::move(originalName), indices);
14981537
}
14991538

1539+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1540+
SourceRange baseRange, TypeRepr *baseTypeRepr,
1541+
DeclNameRefWithLoc originalName,
1542+
ArrayRef<ParsedAutoDiffParameter> params)
1543+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1544+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1545+
NumParsedParameters(params.size()) {
1546+
std::uninitialized_copy(params.begin(), params.end(),
1547+
getTrailingObjects<ParsedAutoDiffParameter>());
1548+
}
1549+
1550+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1551+
SourceRange baseRange, TypeRepr *baseTypeRepr,
1552+
DeclNameRefWithLoc originalName, IndexSubset *indices)
1553+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1554+
BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1555+
ParameterIndices(indices) {}
1556+
1557+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1558+
SourceLoc atLoc, SourceRange baseRange,
1559+
TypeRepr *baseType,
1560+
DeclNameRefWithLoc originalName,
1561+
ArrayRef<ParsedAutoDiffParameter> params) {
1562+
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1563+
void *mem = context.Allocate(size, alignof(TransposeAttr));
1564+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1565+
std::move(originalName), params);
1566+
}
1567+
1568+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1569+
SourceLoc atLoc, SourceRange baseRange,
1570+
TypeRepr *baseType,
1571+
DeclNameRefWithLoc originalName,
1572+
IndexSubset *indices) {
1573+
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
1574+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1575+
std::move(originalName), indices);
1576+
}
1577+
15001578
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
15011579
TypeLoc ProtocolType,
15021580
DeclName MemberName,

0 commit comments

Comments
 (0)