@@ -366,11 +366,25 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
366
366
Printer.printNewline ();
367
367
}
368
368
369
- // Returns the differentiation parameters clause string for the given function,
370
- // parameter indices, and parsed parameters.
369
+ // / Printing style for a differentiation parameter in a `wrt:` differentiation
370
+ // / parameters clause. Used for printing `@differentiable`, `@derivative`, and
371
+ // / `@transpose` attributes.
372
+ enum class DifferentiationParameterPrintingStyle {
373
+ // / Print parameter by name.
374
+ // / Used for `@differentiable` and `@derivative` attribute.
375
+ Name,
376
+ // / Print parameter by index.
377
+ // / Used for `@transpose` attribute.
378
+ Index
379
+ };
380
+
381
+ // / Returns the differentiation parameters clause string for the given function,
382
+ // / parameter indices, parsed parameters, . Use the parameter indices if
383
+ // / specified; otherwise, use the parsed parameters.
371
384
static std::string getDifferentiationParametersClauseString (
372
385
const AbstractFunctionDecl *function, IndexSubset *paramIndices,
373
- ArrayRef<ParsedAutoDiffParameter> parsedParams) {
386
+ ArrayRef<ParsedAutoDiffParameter> parsedParams,
387
+ DifferentiationParameterPrintingStyle style) {
374
388
assert (function);
375
389
bool isInstanceMethod = function->isInstanceMember ();
376
390
std::string result;
@@ -392,7 +406,14 @@ static std::string getDifferentiationParametersClauseString(
392
406
}
393
407
// Print remaining differentiation parameters.
394
408
interleave (parameters.set_bits (), [&](unsigned index) {
395
- printer << function->getParameters ()->get (index)->getName ().str ();
409
+ switch (style) {
410
+ case DifferentiationParameterPrintingStyle::Name:
411
+ printer << function->getParameters ()->get (index)->getName ().str ();
412
+ break ;
413
+ case DifferentiationParameterPrintingStyle::Index:
414
+ printer << index;
415
+ break ;
416
+ }
396
417
}, [&] { printer << " , " ; });
397
418
if (parameterCount > 1 )
398
419
printer << ' )' ;
@@ -425,11 +446,11 @@ static std::string getDifferentiationParametersClauseString(
425
446
return printer.str ();
426
447
}
427
448
428
- // Print the arguments of the given `@differentiable` attribute.
429
- // - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
430
- // parameters clause.
431
- // - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
432
- // functions.
449
+ // / Print the arguments of the given `@differentiable` attribute.
450
+ // / - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
451
+ // / parameters clause.
452
+ // / - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
453
+ // / functions.
433
454
static void printDifferentiableAttrArguments (
434
455
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
435
456
const Decl *D, bool omitWrtClause = false ,
@@ -465,7 +486,8 @@ static void printDifferentiableAttrArguments(
465
486
// Print differentiation parameters clause, unless it is to be omitted.
466
487
if (!omitWrtClause) {
467
488
auto diffParamsString = getDifferentiationParametersClauseString (
468
- original, attr->getParameterIndices (), attr->getParsedParameters ());
489
+ original, attr->getParameterIndices (), attr->getParsedParameters (),
490
+ DifferentiationParameterPrintingStyle::Name);
469
491
// Check whether differentiation parameter clause is empty.
470
492
// Handles edge case where resolved parameter indices are unset and
471
493
// parsed parameters are empty. This case should never trigger for
@@ -897,6 +919,21 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
897
919
break ;
898
920
}
899
921
922
+ case DAK_Transpose: {
923
+ Printer.printAttrName (" @transpose" );
924
+ Printer << " (of: " ;
925
+ auto *attr = cast<TransposeAttr>(this );
926
+ Printer << attr->getOriginalFunctionName ().Name ;
927
+ auto *transpose = cast<AbstractFunctionDecl>(D);
928
+ auto transParamsString = getDifferentiationParametersClauseString (
929
+ transpose, attr->getParameterIndices (), attr->getParsedParameters (),
930
+ DifferentiationParameterPrintingStyle::Index);
931
+ if (!transParamsString.empty ())
932
+ Printer << " , " << transParamsString;
933
+ Printer << ' )' ;
934
+ break ;
935
+ }
936
+
900
937
case DAK_ImplicitlySynthesizesNestedRequirement:
901
938
Printer.printAttrName (" @_implicitly_synthesizes_nested_requirement" );
902
939
Printer << " (\" " << cast<ImplicitlySynthesizesNestedRequirementAttr>(this )->Value << " \" )" ;
@@ -1040,6 +1077,8 @@ StringRef DeclAttribute::getAttrName() const {
1040
1077
return " differentiable" ;
1041
1078
case DAK_Derivative:
1042
1079
return " derivative" ;
1080
+ case DAK_Transpose:
1081
+ return " transpose" ;
1043
1082
}
1044
1083
llvm_unreachable (" bad DeclAttrKind" );
1045
1084
}
@@ -1497,6 +1536,45 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
1497
1536
std::move (originalName), indices);
1498
1537
}
1499
1538
1539
+ TransposeAttr::TransposeAttr (bool implicit, SourceLoc atLoc,
1540
+ SourceRange baseRange, TypeRepr *baseTypeRepr,
1541
+ DeclNameRefWithLoc originalName,
1542
+ ArrayRef<ParsedAutoDiffParameter> params)
1543
+ : DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1544
+ BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1545
+ NumParsedParameters(params.size()) {
1546
+ std::uninitialized_copy (params.begin (), params.end (),
1547
+ getTrailingObjects<ParsedAutoDiffParameter>());
1548
+ }
1549
+
1550
+ TransposeAttr::TransposeAttr (bool implicit, SourceLoc atLoc,
1551
+ SourceRange baseRange, TypeRepr *baseTypeRepr,
1552
+ DeclNameRefWithLoc originalName, IndexSubset *indices)
1553
+ : DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
1554
+ BaseTypeRepr(baseTypeRepr), OriginalFunctionName(std::move(originalName)),
1555
+ ParameterIndices(indices) {}
1556
+
1557
+ TransposeAttr *TransposeAttr::create (ASTContext &context, bool implicit,
1558
+ SourceLoc atLoc, SourceRange baseRange,
1559
+ TypeRepr *baseType,
1560
+ DeclNameRefWithLoc originalName,
1561
+ ArrayRef<ParsedAutoDiffParameter> params) {
1562
+ unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size ());
1563
+ void *mem = context.Allocate (size, alignof (TransposeAttr));
1564
+ return new (mem) TransposeAttr (implicit, atLoc, baseRange, baseType,
1565
+ std::move (originalName), params);
1566
+ }
1567
+
1568
+ TransposeAttr *TransposeAttr::create (ASTContext &context, bool implicit,
1569
+ SourceLoc atLoc, SourceRange baseRange,
1570
+ TypeRepr *baseType,
1571
+ DeclNameRefWithLoc originalName,
1572
+ IndexSubset *indices) {
1573
+ void *mem = context.Allocate (sizeof (TransposeAttr), alignof (TransposeAttr));
1574
+ return new (mem) TransposeAttr (implicit, atLoc, baseRange, baseType,
1575
+ std::move (originalName), indices);
1576
+ }
1577
+
1500
1578
ImplementsAttr::ImplementsAttr (SourceLoc atLoc, SourceRange range,
1501
1579
TypeLoc ProtocolType,
1502
1580
DeclName MemberName,
0 commit comments