Skip to content

Commit b7415d2

Browse files
committed
[AutoDiff upstream] Introduce @transposing attribute.
This PR introduces the `@transposing` attribute to mark functions as transposing other functions. This PR only contains changes related to parsing the attribute. Type checking and other changes will be added in subsequent patches. This work is related to the `@differentiable` attribute in #27506.
1 parent 12c07b4 commit b7415d2

18 files changed

+1058
-22
lines changed

include/swift/AST/Attr.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,12 @@ DECL_ATTR(differentiable, Differentiable,
503503
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
504504
90)
505505

506+
// Note: leaving space for differentiating (91).
507+
508+
DECL_ATTR(transposing, Transposing,
509+
OnFunc | LongAttribute | AllowMultipleAttributes | ABIStableToAdd |
510+
ABIStableToRemove | APIStableToAdd | APIStableToRemove | NotSerialized, 92)
511+
506512
SIMPLE_DECL_ATTR(IBSegueAction, IBSegueAction,
507513
OnFunc |
508514
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,

include/swift/AST/Attr.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,6 +1855,80 @@ class DifferentiableAttr final
18551855
}
18561856
};
18571857

1858+
/// Attribute that registers a function as a transpose of another function.
1859+
///
1860+
/// Examples:
1861+
/// @transposing(foo)
1862+
/// @transposing(+, wrt: (lhs, rhs))
1863+
class TransposingAttr final
1864+
: public DeclAttribute,
1865+
private llvm::TrailingObjects<TransposingAttr,
1866+
ParsedAutoDiffParameter> {
1867+
friend TrailingObjects;
1868+
1869+
/// The base type of the original function.
1870+
/// This is non-null only when the original function is not top-level (i.e. it
1871+
/// is an instance/static method).
1872+
TypeRepr *BaseType;
1873+
/// The original function name.
1874+
DeclNameWithLoc Original;
1875+
/// The original function, resolved by the type checker.
1876+
FuncDecl *OriginalFunction = nullptr;
1877+
/// The number of parsed parameters specified in 'wrt:'.
1878+
unsigned NumParsedParameters = 0;
1879+
/// The differentiation parameters' indices, resolved by the type checker.
1880+
AutoDiffIndexSubset *ParameterIndexSubset = nullptr;
1881+
1882+
explicit TransposingAttr(ASTContext &context, bool implicit,
1883+
SourceLoc atLoc, SourceRange baseRange,
1884+
TypeRepr *baseType, DeclNameWithLoc original,
1885+
ArrayRef<ParsedAutoDiffParameter> params);
1886+
1887+
explicit TransposingAttr(ASTContext &context, bool implicit,
1888+
SourceLoc atLoc, SourceRange baseRange,
1889+
TypeRepr *baseType, DeclNameWithLoc original,
1890+
AutoDiffIndexSubset *indices);
1891+
1892+
public:
1893+
static TransposingAttr *create(ASTContext &context, bool implicit,
1894+
SourceLoc atLoc, SourceRange baseRange,
1895+
TypeRepr *baseType, DeclNameWithLoc original,
1896+
ArrayRef<ParsedAutoDiffParameter> params);
1897+
1898+
static TransposingAttr *create(ASTContext &context, bool implicit,
1899+
SourceLoc atLoc, SourceRange baseRange,
1900+
TypeRepr *baseType, DeclNameWithLoc original,
1901+
AutoDiffIndexSubset *indices);
1902+
1903+
TypeRepr *getBaseType() const { return BaseType; }
1904+
DeclNameWithLoc getOriginal() const { return Original; }
1905+
1906+
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
1907+
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1908+
1909+
/// The parsed transposing parameters, i.e. the list of parameters
1910+
/// specified in 'wrt:'.
1911+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1912+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1913+
}
1914+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1915+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1916+
}
1917+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1918+
return NumParsedParameters;
1919+
}
1920+
1921+
AutoDiffIndexSubset *getParameterIndexSubset() const {
1922+
return ParameterIndexSubset;
1923+
}
1924+
void setParameterIndices(AutoDiffIndexSubset *pi) {
1925+
ParameterIndexSubset = pi;
1926+
}
1927+
1928+
static bool classof(const DeclAttribute *DA) {
1929+
return DA->getKind() == DAK_Transposing;
1930+
}
1931+
};
18581932

18591933
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
18601934

include/swift/AST/DiagnosticsParse.def

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,17 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15341534
"expected a parameter, which can be a function parameter name, "
15351535
"parameter index, or 'self'", ())
15361536

1537+
// transposing
1538+
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
1539+
"expected an original function name", ())
1540+
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
1541+
"expected 'wrt:'", ())
1542+
1543+
// transposing `wrt` parameters clause
1544+
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,
1545+
"expected a parameter, which can be a 'unsigned int' parameter number "
1546+
"or 'self'", ())
1547+
15371548
//------------------------------------------------------------------------------
15381549
// MARK: Generics parsing diagnostics
15391550
//------------------------------------------------------------------------------

include/swift/Parse/Parser.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,14 @@ class Parser {
10921092
bool parseDifferentiationParametersClause(
10931093
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
10941094

1095+
/// Parse a transposing parameters clause.
1096+
bool parseTransposingParametersClause(
1097+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
1098+
1099+
/// Parse a transposing attribute
1100+
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
1101+
SourceLoc Loc);
1102+
10951103
/// Parse a specific attribute.
10961104
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);
10971105

@@ -1226,7 +1234,7 @@ class Parser {
12261234
parseTypeSimple(Diag<> MessageID, bool HandleCodeCompletion);
12271235
ParsedSyntaxResult<ParsedTypeSyntax>
12281236
parseTypeSimpleOrComposition(Diag<> MessageID, bool HandleCodeCompletion);
1229-
ParsedSyntaxResult<ParsedTypeSyntax> parseTypeIdentifier();
1237+
ParsedSyntaxResult<ParsedTypeSyntax> parseTypeIdentifier(bool isParsingQualifiedDeclName = false);
12301238
ParsedSyntaxResult<ParsedTypeSyntax> parseAnyType();
12311239
ParsedSyntaxResult<ParsedTypeSyntax> parseTypeTupleBody();
12321240
ParsedSyntaxResult<ParsedTypeSyntax> parseTypeCollection();
@@ -1442,6 +1450,12 @@ class Parser {
14421450

14431451
bool canParseTypedPattern();
14441452

1453+
/// Determines whether a type qualifier for a decl name can be parsed. e.g.:
1454+
/// 'Foo.f' -> true
1455+
/// 'Foo.Bar.f' -> true
1456+
/// 'f' -> false
1457+
bool canParseTypeQualifierForDeclName();
1458+
14451459
//===--------------------------------------------------------------------===//
14461460
// Expression Parsing
14471461
ParserResult<Expr> parseExpr(Diag<> ID) {

lib/AST/ASTContext.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,12 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
434434

435435
llvm::StringMap<OptionSet<SearchPathKind>> SearchPathsSet;
436436

437+
/// For uniquifying `AutoDiffParameterIndices` allocations.
438+
llvm::FoldingSet<AutoDiffParameterIndices> AutoDiffParameterIndicesSet;
439+
440+
/// For uniquifying `AutoDiffIndexSubset` allocations.
441+
llvm::FoldingSet<AutoDiffIndexSubset> AutoDiffIndexSubsets;
442+
437443
/// The permanent arena.
438444
Arena Permanent;
439445

@@ -4618,3 +4624,49 @@ void VarDecl::setOriginalWrappedProperty(VarDecl *originalProperty) {
46184624
assert(ctx.getImpl().OriginalWrappedProperties.count(this) == 0);
46194625
ctx.getImpl().OriginalWrappedProperties[this] = originalProperty;
46204626
}
4627+
4628+
AutoDiffParameterIndices *
4629+
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
4630+
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;
4631+
4632+
llvm::FoldingSetNodeID id;
4633+
id.AddInteger(indices.size());
4634+
for (unsigned setBit : indices.set_bits())
4635+
id.AddInteger(setBit);
4636+
4637+
void *insertPos;
4638+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4639+
if (existing)
4640+
return existing;
4641+
4642+
// TODO(SR-9290): Note that the AutoDiffParameterIndices' destructor never
4643+
// gets called, which causes a small memory leak in the case that the
4644+
// SmallBitVector decides to allocate some heap space.
4645+
void *mem = C.Allocate(sizeof(AutoDiffParameterIndices),
4646+
alignof(AutoDiffParameterIndices));
4647+
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices);
4648+
foldingSet.InsertNode(newNode, insertPos);
4649+
4650+
return newNode;
4651+
}
4652+
4653+
AutoDiffIndexSubset *
4654+
AutoDiffIndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
4655+
auto &foldingSet = ctx.getImpl().AutoDiffIndexSubsets;
4656+
llvm::FoldingSetNodeID id;
4657+
unsigned capacity = indices.size();
4658+
id.AddInteger(capacity);
4659+
for (unsigned index : indices.set_bits())
4660+
id.AddInteger(index);
4661+
void *insertPos = nullptr;
4662+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4663+
if (existing)
4664+
return existing;
4665+
auto sizeToAlloc = sizeof(AutoDiffIndexSubset) +
4666+
getNumBitWordsNeededForCapacity(capacity);
4667+
auto *buf = reinterpret_cast<AutoDiffIndexSubset *>(
4668+
ctx.Allocate(sizeToAlloc, alignof(AutoDiffIndexSubset)));
4669+
auto *newNode = new (buf) AutoDiffIndexSubset(indices);
4670+
foldingSet.InsertNode(newNode, insertPos);
4671+
return newNode;
4672+
}

lib/AST/Attr.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,67 @@ static std::string getDifferentiationParametersClauseString(
409409
return printer.str();
410410
}
411411

412+
// Returns the differentiation parameters clause string for the given function,
413+
// parameter indices, and parsed parameters.
414+
static std::string getTransposingParametersClauseString(
415+
const AbstractFunctionDecl *function, AutoDiffIndexSubset *indices,
416+
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
417+
bool isInstanceMethod = function && function->isInstanceMember();
418+
419+
std::string result;
420+
llvm::raw_string_ostream printer(result);
421+
422+
// Use parameters from `AutoDiffIndexSubset`, if specified.
423+
if (indices) {
424+
SmallBitVector parameters(indices->getBitVector());
425+
auto parameterCount = parameters.count();
426+
printer << "wrt: ";
427+
if (parameterCount > 1)
428+
printer << '(';
429+
// Check if differentiating wrt `self`. If so, manually print it first.
430+
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
431+
parameters.reset(parameters.size() - 1);
432+
printer << "self";
433+
if (parameters.any())
434+
printer << ", ";
435+
}
436+
// Print remaining differentiation parameters.
437+
interleave(parameters.set_bits(), [&](unsigned index) { printer << index; },
438+
[&] { printer << ", "; });
439+
if (parameterCount > 1)
440+
printer << ')';
441+
}
442+
// Otherwise, use the parsed parameters.
443+
else if (!parsedParams.empty()) {
444+
printer << "wrt: ";
445+
if (parsedParams.size() > 1)
446+
printer << '(';
447+
interleave(
448+
parsedParams,
449+
[&](const ParsedAutoDiffParameter &param) {
450+
switch (param.getKind()) {
451+
case ParsedAutoDiffParameter::Kind::Named:
452+
printer << param.getName();
453+
break;
454+
case ParsedAutoDiffParameter::Kind::Self:
455+
printer << "self";
456+
break;
457+
case ParsedAutoDiffParameter::Kind::Ordered:
458+
assert((param.getIndex() < function->getParameters()->size()) &&
459+
"'wrt:' parameter index should be less than the number "
460+
"of parameters");
461+
auto *funcParam = function->getParameters()->get(param.getIndex());
462+
printer << funcParam->getNameStr();
463+
break;
464+
}
465+
},
466+
[&] { printer << ", "; });
467+
if (parsedParams.size() > 1)
468+
printer << ')';
469+
}
470+
return printer.str();
471+
}
472+
412473
// Print the arguments of the given `@differentiable` attribute.
413474
static void printDifferentiableAttrArguments(
414475
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
@@ -832,6 +893,28 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
832893
break;
833894
}
834895

896+
case DAK_Differentiable: {
897+
Printer.printAttrName("@differentiable");
898+
auto *attr = cast<DifferentiableAttr>(this);
899+
printDifferentiableAttrArguments(attr, Printer, Options, D);
900+
break;
901+
}
902+
903+
case DAK_Transposing: {
904+
Printer.printAttrName("@transposing");
905+
Printer << '(';
906+
auto *attr = cast<TransposingAttr>(this);
907+
auto *transpose = dyn_cast_or_null<AbstractFunctionDecl>(D);
908+
Printer << attr->getOriginal().Name;
909+
auto diffParamsString = getTransposingParametersClauseString(
910+
transpose, attr->getParameterIndexSubset(),
911+
attr->getParsedParameters());
912+
if (!diffParamsString.empty())
913+
Printer << ", " << diffParamsString;
914+
Printer << ')';
915+
break;
916+
}
917+
835918
case DAK_DynamicReplacement: {
836919
Printer.printAttrName("@_dynamicReplacement");
837920
Printer << "(for: \"";
@@ -989,6 +1072,8 @@ StringRef DeclAttribute::getAttrName() const {
9891072
return "_projectedValueProperty";
9901073
case DAK_Differentiable:
9911074
return "differentiable";
1075+
case DAK_Transposing:
1076+
return "transposing";
9921077
}
9931078
llvm_unreachable("bad DeclAttrKind");
9941079
}
@@ -1410,6 +1495,47 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
14101495
omitAssociatedFunctions);
14111496
}
14121497

1498+
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1499+
SourceLoc atLoc, SourceRange baseRange,
1500+
TypeRepr *baseType, DeclNameWithLoc original,
1501+
ArrayRef<ParsedAutoDiffParameter> params)
1502+
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1503+
BaseType(baseType), Original(std::move(original)),
1504+
NumParsedParameters(params.size()) {
1505+
std::uninitialized_copy(params.begin(), params.end(),
1506+
getTrailingObjects<ParsedAutoDiffParameter>());
1507+
}
1508+
1509+
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1510+
SourceLoc atLoc, SourceRange baseRange,
1511+
TypeRepr *baseType, DeclNameWithLoc original,
1512+
AutoDiffIndexSubset *indices)
1513+
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1514+
BaseType(baseType), Original(std::move(original)),
1515+
ParameterIndexSubset(indices) {}
1516+
1517+
TransposingAttr *
1518+
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1519+
SourceRange baseRange, TypeRepr *baseType,
1520+
DeclNameWithLoc original,
1521+
ArrayRef<ParsedAutoDiffParameter> params) {
1522+
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1523+
void *mem = context.Allocate(size, alignof(TransposingAttr));
1524+
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1525+
baseType, std::move(original), params);
1526+
}
1527+
1528+
TransposingAttr *
1529+
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1530+
SourceRange baseRange, TypeRepr *baseType,
1531+
DeclNameWithLoc original,
1532+
AutoDiffIndexSubset *indices) {
1533+
void *mem =
1534+
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
1535+
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1536+
baseType, std::move(original), indices);
1537+
}
1538+
14131539
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
14141540
TypeLoc ProtocolType,
14151541
DeclName MemberName,

0 commit comments

Comments
 (0)