Skip to content

Commit b9499c3

Browse files
authored
[AutoDiff] Rename @transposing to @transpose(of:). (#28488)
Rename `@transposing` to `@transpose(of:)`. `@transpose(of:)` more clearly evokes transpose registration; the syntax is otherwise unchanged. Discussed here: #28321 (comment) Remove `@transposing`. Removal without deprecation should be fine because there are no known users of `@transposing` attribute. Resolves TF-992. TF-1009 tracks `@transpose` syntax support for qualified names.
1 parent 82db8fa commit b9499c3

20 files changed

+297
-255
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
524524
OnVar |
525525
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
526526
94)
527-
DECL_ATTR(transposing, Transposing,
527+
DECL_ATTR(transpose, Transpose,
528528
OnFunc | LongAttribute | AllowMultipleAttributes |
529529
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
530530
NotSerialized, 96)

include/swift/AST/Attr.h

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,12 +1743,11 @@ using DifferentiatingAttr = DerivativeAttr;
17431743
/// Attribute that registers a function as a transpose of another function.
17441744
///
17451745
/// Examples:
1746-
/// @transposing(foo)
1747-
/// @transposing(+, wrt: (lhs, rhs))
1748-
class TransposingAttr final
1749-
: public DeclAttribute,
1750-
private llvm::TrailingObjects<TransposingAttr,
1751-
ParsedAutoDiffParameter> {
1746+
/// @transpose(of: foo)
1747+
/// @transpose(of: +, wrt: (lhs, rhs))
1748+
class TransposeAttr final
1749+
: public DeclAttribute,
1750+
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
17521751
friend TrailingObjects;
17531752

17541753
/// The base type of the original function.
@@ -1761,28 +1760,27 @@ class TransposingAttr final
17611760
AbstractFunctionDecl *OriginalFunction = nullptr;
17621761
/// The number of parsed parameters specified in 'wrt:'.
17631762
unsigned NumParsedParameters = 0;
1764-
/// The differentiation parameters' indices, resolved by the type checker.
1763+
/// The transposed parameters' indices, resolved by the type checker.
17651764
IndexSubset *ParameterIndices = nullptr;
17661765

1767-
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1768-
SourceRange baseRange, TypeRepr *baseType,
1769-
DeclNameWithLoc original,
1770-
ArrayRef<ParsedAutoDiffParameter> params);
1766+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1767+
TypeRepr *baseType, DeclNameWithLoc original,
1768+
ArrayRef<ParsedAutoDiffParameter> params);
17711769

1772-
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1773-
SourceRange baseRange, TypeRepr *baseType,
1774-
DeclNameWithLoc original, IndexSubset *indices);
1770+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1771+
TypeRepr *baseType, DeclNameWithLoc original,
1772+
IndexSubset *indices);
17751773

17761774
public:
1777-
static TransposingAttr *create(ASTContext &context, bool implicit,
1778-
SourceLoc atLoc, SourceRange baseRange,
1779-
TypeRepr *baseType, DeclNameWithLoc original,
1780-
ArrayRef<ParsedAutoDiffParameter> params);
1775+
static TransposeAttr *create(ASTContext &context, bool implicit,
1776+
SourceLoc atLoc, SourceRange baseRange,
1777+
TypeRepr *baseType, DeclNameWithLoc original,
1778+
ArrayRef<ParsedAutoDiffParameter> params);
17811779

1782-
static TransposingAttr *create(ASTContext &context, bool implicit,
1783-
SourceLoc atLoc, SourceRange baseRange,
1784-
TypeRepr *baseType, DeclNameWithLoc original,
1785-
IndexSubset *indices);
1780+
static TransposeAttr *create(ASTContext &context, bool implicit,
1781+
SourceLoc atLoc, SourceRange baseRange,
1782+
TypeRepr *baseType, DeclNameWithLoc original,
1783+
IndexSubset *indices);
17861784

17871785
TypeRepr *getBaseType() const { return BaseType; }
17881786
DeclNameWithLoc getOriginalFunctionName() const {
@@ -1795,8 +1793,8 @@ class TransposingAttr final
17951793
OriginalFunction = decl;
17961794
}
17971795

1798-
/// The parsed transposing parameters, i.e. the list of parameters
1799-
/// specified in 'wrt:'.
1796+
/// The parsed transposed parameters, i.e. the list of parameters specified in
1797+
/// 'wrt:'.
18001798
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
18011799
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
18021800
}
@@ -1815,7 +1813,7 @@ class TransposingAttr final
18151813
}
18161814

18171815
static bool classof(const DeclAttribute *DA) {
1818-
return DA->getKind() == DAK_Transposing;
1816+
return DA->getKind() == DAK_Transpose;
18191817
}
18201818
};
18211819

include/swift/AST/DiagnosticsParse.def

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,14 +1572,14 @@ WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
15721572
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
15731573
"instead", ())
15741574

1575-
// transposing
1576-
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
1575+
// transpose
1576+
ERROR(attr_transpose_expected_original_name,PointsToFirstBadToken,
15771577
"expected an original function name", ())
1578-
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
1578+
ERROR(attr_transpose_expected_label_linear_or_wrt,none,
15791579
"expected 'wrt:'", ())
15801580

1581-
// transposing `wrt` parameters clause
1582-
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,
1581+
// transpose `wrt` parameters clause
1582+
ERROR(transpose_params_clause_expected_parameter,PointsToFirstBadToken,
15831583
"expected a parameter, which can be a 'unsigned int' parameter number "
15841584
"or 'self'", ())
15851585

include/swift/AST/DiagnosticsSema.def

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,17 +3024,17 @@ ERROR(derivative_attr_original_stored_property_unsupported,none,
30243024
ERROR(derivative_attr_original_already_has_derivative,none,
30253025
"a derivative already exists for %0", (DeclName))
30263026

3027-
// transposing
3027+
// @transpose
30283028
ERROR(transpose_params_clause_param_not_differentiable,none,
30293029
"can only transpose with respect to parameters that conform to "
30303030
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
3031-
ERROR(transposing_attr_overload_not_found,none,
3031+
ERROR(transpose_attr_overload_not_found,none,
30323032
"could not find function %0 with expected type %1", (DeclName, Type))
3033-
ERROR(transposing_attr_cannot_use_named_wrt_params,none,
3034-
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
3035-
(Identifier))
3036-
ERROR(transposing_attr_result_value_not_differentiable,none,
3037-
"'@transposing' attribute requires original function result %0 to "
3033+
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
3034+
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
3035+
"%0", (Identifier))
3036+
ERROR(transpose_attr_result_value_not_differentiable,none,
3037+
"'@transpose(of:)' attribute requires original function result %0 to "
30383038
"conform to 'Differentiable'", (Type))
30393039

30403040
// differentiation `wrt` parameters clause

include/swift/AST/Types.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,7 +3142,7 @@ class AnyFunctionType : public TypeBase {
31423142
///
31433143
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
31443144
/// first. This should be used during type-checking, e.g. type-checking
3145-
/// `@differentiable`, `@derivative`, and `@transposing` attributes.
3145+
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
31463146
///
31473147
/// \note The original function type (`self`) need not be `@differentiable`.
31483148
/// The resulting function will preserve all `ExtInfo` of the original
@@ -3158,11 +3158,10 @@ class AnyFunctionType : public TypeBase {
31583158
/// corresponding original function type.
31593159
AnyFunctionType *getAutoDiffOriginalFunctionType();
31603160

3161-
/// Given the type of a transposing derivative function, returns the
3162-
/// corresponding original function type.
3161+
/// Given the type of a transpose function, returns the corresponding original
3162+
/// function type.
31633163
AnyFunctionType *
3164-
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices,
3165-
bool wrtSelf);
3164+
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices, bool wrtSelf);
31663165

31673166
AnyFunctionType *getWithoutDifferentiability() const;
31683167

include/swift/Parse/Parser.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -996,13 +996,10 @@ class Parser {
996996
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
997997
TrailingWhereClause *&whereClause);
998998

999-
/// Parse a differentiation parameters clause.
999+
/// Parse a differentiation parameters clause, i.e. the "wrt:" clause in
1000+
/// @differentiable and @derivative attributes.
10001001
bool parseDifferentiationParametersClause(
10011002
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1002-
1003-
/// Parse a transposing parameters clause.
1004-
bool parseTransposingParametersClause(
1005-
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
10061003

10071004
/// Parse the @derivative attribute.
10081005
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
@@ -1013,9 +1010,14 @@ class Parser {
10131010
ParserResult<DerivativeAttr> parseDifferentiatingAttribute(SourceLoc AtLoc,
10141011
SourceLoc Loc);
10151012

1016-
/// Parse the @transposing attribute.
1017-
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
1018-
SourceLoc Loc);
1013+
/// Parse a transposed parameters clause, i.e. the "wrt:" clause in @transpose
1014+
/// attributes.
1015+
bool parseTransposedParametersClause(
1016+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1017+
1018+
/// Parse the @transpose attribute.
1019+
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
1020+
SourceLoc Loc);
10191021

10201022
/// Parse the @quoted attribute.
10211023
ParserResult<QuotedAttr> parseQuotedAttribute(SourceLoc AtLoc, SourceLoc Loc);

lib/AST/Attr.cpp

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -936,10 +936,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
936936
}
937937

938938
// SWIFT_ENABLE_TENSORFLOW
939-
case DAK_Transposing: {
940-
Printer.printAttrName("@transposing");
939+
case DAK_Transpose: {
940+
Printer.printAttrName("@transpose");
941941
Printer << '(';
942-
auto *attr = cast<TransposingAttr>(this);
942+
auto *attr = cast<TransposeAttr>(this);
943943
auto *transpose = cast<AbstractFunctionDecl>(D);
944944
Printer << attr->getOriginalFunctionName().Name;
945945
auto transParamsString = getTransposedParametersClauseString(
@@ -1110,12 +1110,13 @@ StringRef DeclAttribute::getAttrName() const {
11101110
return "differentiable";
11111111
case DAK_Derivative:
11121112
return "derivative";
1113+
case DAK_Transpose:
1114+
return "transpose";
11131115
case DAK_Differentiating:
11141116
return "differentiating";
1115-
case DAK_Transposing:
1116-
return "transposing";
11171117
case DAK_Quoted:
11181118
return "quoted";
1119+
// SWIFT_ENABLE_TENSORFLOW END
11191120
}
11201121
llvm_unreachable("bad DeclAttrKind");
11211122
}
@@ -1608,45 +1609,43 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
16081609
std::move(originalName), indices);
16091610
}
16101611

1611-
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1612-
SourceRange baseRange, TypeRepr *baseType,
1613-
DeclNameWithLoc originalName,
1614-
ArrayRef<ParsedAutoDiffParameter> params)
1615-
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1612+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1613+
SourceRange baseRange, TypeRepr *baseType,
1614+
DeclNameWithLoc originalName,
1615+
ArrayRef<ParsedAutoDiffParameter> params)
1616+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
16161617
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16171618
NumParsedParameters(params.size()) {
16181619
std::uninitialized_copy(params.begin(), params.end(),
16191620
getTrailingObjects<ParsedAutoDiffParameter>());
16201621
}
16211622

1622-
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1623-
SourceRange baseRange, TypeRepr *baseType,
1624-
DeclNameWithLoc originalName,
1625-
IndexSubset *indices)
1626-
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1623+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1624+
SourceRange baseRange, TypeRepr *baseType,
1625+
DeclNameWithLoc originalName, IndexSubset *indices)
1626+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
16271627
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16281628
ParameterIndices(indices) {}
16291629

1630-
TransposingAttr *
1631-
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1632-
SourceRange baseRange, TypeRepr *baseType,
1633-
DeclNameWithLoc original,
1634-
ArrayRef<ParsedAutoDiffParameter> params) {
1630+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1631+
SourceLoc atLoc, SourceRange baseRange,
1632+
TypeRepr *baseType,
1633+
DeclNameWithLoc originalName,
1634+
ArrayRef<ParsedAutoDiffParameter> params) {
16351635
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1636-
void *mem = context.Allocate(size, alignof(TransposingAttr));
1637-
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1638-
std::move(original), params);
1639-
}
1640-
1641-
TransposingAttr *
1642-
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1643-
SourceRange baseRange, TypeRepr *baseType,
1644-
DeclNameWithLoc original,
1645-
IndexSubset *indices) {
1646-
void *mem =
1647-
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
1648-
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1649-
std::move(original), indices);
1636+
void *mem = context.Allocate(size, alignof(TransposeAttr));
1637+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1638+
std::move(originalName), params);
1639+
}
1640+
1641+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1642+
SourceLoc atLoc, SourceRange baseRange,
1643+
TypeRepr *baseType,
1644+
DeclNameWithLoc originalName,
1645+
IndexSubset *indices) {
1646+
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
1647+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1648+
std::move(originalName), indices);
16501649
}
16511650

16521651
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/AST/Type.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4897,8 +4897,7 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
48974897
assert(originalResult);
48984898

48994899
SmallVector<TupleTypeElt, 4> transposeResultTypes;
4900-
// Return type of '@transposing' function can have single type or tuples
4901-
// of types.
4900+
// Return type of transpose function can be a singular type or a tuple type.
49024901
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
49034902
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
49044903
transposeResultTupleType->getElements().end());

0 commit comments

Comments
 (0)