@@ -366,51 +366,60 @@ static void printShortFormAvailable(ArrayRef<const DeclAttribute *> Attrs,
366
366
Printer.printNewline ();
367
367
}
368
368
369
- // / Printing style for a differentiation parameter in a `wrt:` differentiation
370
- // / parameters clause. Used for printing `@differentiable`, `@derivative`, and
371
- // / `@transpose` attributes.
372
- enum class DifferentiationParameterPrintingStyle {
373
- // / Print parameter by name.
369
+ // / The kind of a parameter in a `wrt:` differentiation parameters clause:
370
+ // / either a differentiability parameter or a linearity parameter. Used for
371
+ // / printing `@differentiable`, `@derivative`, and `@transpose` attributes.
372
+ enum class DifferentiationParameterKind {
373
+ // / A differentiability parameter, printed by name.
374
374
// / Used for `@differentiable` and `@derivative` attribute.
375
- Name ,
376
- // / Print parameter by index.
375
+ Differentiability ,
376
+ // / A linearity parameter, printed by index.
377
377
// / Used for `@transpose` attribute.
378
- Index
378
+ Linearity
379
379
};
380
380
381
381
// / Returns the differentiation parameters clause string for the given function,
382
- // / parameter indices, parsed parameters, . Use the parameter indices if
383
- // / specified; otherwise, use the parsed parameters.
382
+ // / parameter indices, parsed parameters, and differentiation parameter kind.
383
+ // / Use the parameter indices if specified; otherwise, use the parsed
384
+ // / parameters.
384
385
static std::string getDifferentiationParametersClauseString (
385
- const AbstractFunctionDecl *function, IndexSubset *paramIndices ,
386
+ const AbstractFunctionDecl *function, IndexSubset *parameterIndices ,
386
387
ArrayRef<ParsedAutoDiffParameter> parsedParams,
387
- DifferentiationParameterPrintingStyle style ) {
388
+ DifferentiationParameterKind parameterKind ) {
388
389
assert (function);
389
390
bool isInstanceMethod = function->isInstanceMember ();
391
+ bool isStaticMethod = function->isStatic ();
390
392
std::string result;
391
393
llvm::raw_string_ostream printer (result);
392
394
393
395
// Use the parameter indices, if specified.
394
- if (paramIndices ) {
395
- auto parameters = paramIndices ->getBitVector ();
396
+ if (parameterIndices ) {
397
+ auto parameters = parameterIndices ->getBitVector ();
396
398
auto parameterCount = parameters.count ();
397
399
printer << " wrt: " ;
398
400
if (parameterCount > 1 )
399
401
printer << ' (' ;
400
402
// Check if differentiating wrt `self`. If so, manually print it first.
401
- if (isInstanceMethod && parameters.test (parameters.size () - 1 )) {
403
+ bool isWrtSelf =
404
+ (isInstanceMethod ||
405
+ (isStaticMethod &&
406
+ parameterKind == DifferentiationParameterKind::Linearity)) &&
407
+ parameters.test (parameters.size () - 1 );
408
+ if (isWrtSelf) {
402
409
parameters.reset (parameters.size () - 1 );
403
410
printer << " self" ;
404
411
if (parameters.any ())
405
412
printer << " , " ;
406
413
}
407
414
// Print remaining differentiation parameters.
408
415
interleave (parameters.set_bits (), [&](unsigned index) {
409
- switch (style) {
410
- case DifferentiationParameterPrintingStyle::Name:
416
+ switch (parameterKind) {
417
+ // Print differentiability parameters by name.
418
+ case DifferentiationParameterKind::Differentiability:
411
419
printer << function->getParameters ()->get (index)->getName ().str ();
412
420
break ;
413
- case DifferentiationParameterPrintingStyle::Index:
421
+ // Print linearity parameters by index.
422
+ case DifferentiationParameterKind::Linearity:
414
423
printer << index;
415
424
break ;
416
425
}
@@ -487,7 +496,7 @@ static void printDifferentiableAttrArguments(
487
496
if (!omitWrtClause) {
488
497
auto diffParamsString = getDifferentiationParametersClauseString (
489
498
original, attr->getParameterIndices (), attr->getParsedParameters (),
490
- DifferentiationParameterPrintingStyle::Name );
499
+ DifferentiationParameterKind::Differentiability );
491
500
// Check whether differentiation parameter clause is empty.
492
501
// Handles edge case where resolved parameter indices are unset and
493
502
// parsed parameters are empty. This case should never trigger for
@@ -927,7 +936,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
927
936
auto *derivative = cast<AbstractFunctionDecl>(D);
928
937
auto diffParamsString = getDifferentiationParametersClauseString (
929
938
derivative, attr->getParameterIndices (), attr->getParsedParameters (),
930
- DifferentiationParameterPrintingStyle::Name );
939
+ DifferentiationParameterKind::Differentiability );
931
940
if (!diffParamsString.empty ())
932
941
Printer << " , " << diffParamsString;
933
942
Printer << ' )' ;
@@ -942,7 +951,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
942
951
auto *transpose = cast<AbstractFunctionDecl>(D);
943
952
auto transParamsString = getDifferentiationParametersClauseString (
944
953
transpose, attr->getParameterIndices (), attr->getParsedParameters (),
945
- DifferentiationParameterPrintingStyle::Index );
954
+ DifferentiationParameterKind::Linearity );
946
955
if (!transParamsString.empty ())
947
956
Printer << " , " << transParamsString;
948
957
Printer << ' )' ;
@@ -1510,11 +1519,11 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
1510
1519
1511
1520
void DifferentiableAttr::print (llvm::raw_ostream &OS, const Decl *D,
1512
1521
bool omitWrtClause,
1513
- bool omitAssociatedFunctions ) const {
1522
+ bool omitDerivativeFunctions ) const {
1514
1523
StreamPrinter P (OS);
1515
1524
P << " @" << getAttrName ();
1516
1525
printDifferentiableAttrArguments (this , P, PrintOptions (), D, omitWrtClause,
1517
- omitAssociatedFunctions );
1526
+ omitDerivativeFunctions );
1518
1527
}
1519
1528
1520
1529
DerivativeAttr::DerivativeAttr (bool implicit, SourceLoc atLoc,
0 commit comments