Skip to content

Commit e90816b

Browse files
authored
[AutoDiff] Clean up differentiation-related attributes. (#28466)
- Remove obsolete `linear` flag from `@differentiating` attribute. The `@transposing` attribute is now used for transpose registration. - Unify parsing diagnostics. - Unify duplicated parsing `errorAndSkipToEnd` logic. - Create a shared `errorAndSkipUntilConsumeRightParen` helper and use it when parsing `@differentiating` and `@transposing` attribute to improve (reduce) diagnostics. - Unify syntax for `@differentiating` and `@transposing` attributes: `DerivativeRegistrationAttributeArguments`. - Remove unused `ASTContext &` argument from `DifferentiableAttr` and `TransposingAttr` constructors. - Minor formatting changes.
1 parent e0789da commit e90816b

File tree

12 files changed

+143
-239
lines changed

12 files changed

+143
-239
lines changed

include/swift/AST/Attr.h

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,9 +1570,8 @@ class DifferentiableAttr final
15701570
/// has a where clause.
15711571
GenericSignature DerivativeGenericSignature = GenericSignature();
15721572

1573-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1574-
SourceLoc atLoc, SourceRange baseRange,
1575-
bool linear,
1573+
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
1574+
SourceRange baseRange, bool linear,
15761575
ArrayRef<ParsedAutoDiffParameter> parameters,
15771576
Optional<DeclNameWithLoc> jvp,
15781577
Optional<DeclNameWithLoc> vjp,
@@ -1683,38 +1682,31 @@ class DifferentiatingAttr final
16831682
DeclNameWithLoc Original;
16841683
/// The original function, resolved by the type checker.
16851684
FuncDecl *OriginalFunction = nullptr;
1686-
/// Whether this function is linear.
1687-
bool Linear;
16881685
/// The number of parsed parameters specified in 'wrt:'.
16891686
unsigned NumParsedParameters = 0;
16901687
/// The differentiation parameters' indices, resolved by the type checker.
16911688
IndexSubset *ParameterIndices = nullptr;
16921689

1693-
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
1694-
SourceLoc atLoc, SourceRange baseRange,
1695-
DeclNameWithLoc original, bool linear,
1690+
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1691+
SourceRange baseRange, DeclNameWithLoc original,
16961692
ArrayRef<ParsedAutoDiffParameter> params);
16971693

1698-
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
1699-
SourceLoc atLoc, SourceRange baseRange,
1700-
DeclNameWithLoc original, bool linear,
1694+
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1695+
SourceRange baseRange, DeclNameWithLoc original,
17011696
IndexSubset *indices);
17021697

17031698
public:
17041699
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
17051700
SourceLoc atLoc, SourceRange baseRange,
1706-
DeclNameWithLoc original, bool linear,
1701+
DeclNameWithLoc original,
17071702
ArrayRef<ParsedAutoDiffParameter> params);
17081703

17091704
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
17101705
SourceLoc atLoc, SourceRange baseRange,
1711-
DeclNameWithLoc original, bool linear,
1706+
DeclNameWithLoc original,
17121707
IndexSubset *indices);
17131708

17141709
DeclNameWithLoc getOriginal() const { return Original; }
1715-
1716-
bool isLinear() const { return Linear; }
1717-
17181710
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
17191711
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
17201712

@@ -1765,34 +1757,33 @@ class TransposingAttr final
17651757
unsigned NumParsedParameters = 0;
17661758
/// The differentiation parameters' indices, resolved by the type checker.
17671759
IndexSubset *ParameterIndices = nullptr;
1768-
1769-
explicit TransposingAttr(ASTContext &context, bool implicit,
1770-
SourceLoc atLoc, SourceRange baseRange,
1771-
TypeRepr *baseType, DeclNameWithLoc original,
1760+
1761+
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1762+
SourceRange baseRange, TypeRepr *baseType,
1763+
DeclNameWithLoc original,
17721764
ArrayRef<ParsedAutoDiffParameter> params);
1773-
1774-
explicit TransposingAttr(ASTContext &context, bool implicit,
1775-
SourceLoc atLoc, SourceRange baseRange,
1776-
TypeRepr *baseType, DeclNameWithLoc original,
1777-
IndexSubset *indices);
1778-
1765+
1766+
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1767+
SourceRange baseRange, TypeRepr *baseType,
1768+
DeclNameWithLoc original, IndexSubset *indices);
1769+
17791770
public:
17801771
static TransposingAttr *create(ASTContext &context, bool implicit,
17811772
SourceLoc atLoc, SourceRange baseRange,
17821773
TypeRepr *baseType, DeclNameWithLoc original,
17831774
ArrayRef<ParsedAutoDiffParameter> params);
1784-
1775+
17851776
static TransposingAttr *create(ASTContext &context, bool implicit,
17861777
SourceLoc atLoc, SourceRange baseRange,
17871778
TypeRepr *baseType, DeclNameWithLoc original,
17881779
IndexSubset *indices);
1789-
1780+
17901781
TypeRepr *getBaseType() const { return BaseType; }
17911782
DeclNameWithLoc getOriginal() const { return Original; }
1792-
1783+
17931784
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
17941785
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1795-
1786+
17961787
/// The parsed transposing parameters, i.e. the list of parameters
17971788
/// specified in 'wrt:'.
17981789
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
@@ -1804,14 +1795,14 @@ class TransposingAttr final
18041795
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
18051796
return NumParsedParameters;
18061797
}
1807-
1798+
18081799
IndexSubset *getParameterIndices() const {
18091800
return ParameterIndices;
18101801
}
18111802
void setParameterIndices(IndexSubset *pi) {
18121803
ParameterIndices = pi;
18131804
}
1814-
1805+
18151806
static bool classof(const DeclAttribute *DA) {
18161807
return DA->getKind() == DAK_Transposing;
18171808
}

include/swift/AST/DiagnosticsParse.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,17 +1557,17 @@ ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15571557
"expected a list of parameters to differentiate with respect to", ())
15581558
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
15591559
"use 'wrt:' to specify parameters to differentiate with respect to", ())
1560-
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
1561-
"missing label '%0:' in '@differentiable' attribute", (StringRef))
15621560
ERROR(attr_differentiable_expected_label,none,
15631561
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
15641562
"or 'vjp:'", ())
15651563

15661564
// differentiating
15671565
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15681566
"expected an original function name", ())
1569-
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
1570-
"expected either 'linear' or 'wrt:'", ())
1567+
ERROR(attr_missing_label,PointsToFirstBadToken,
1568+
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
1569+
ERROR(attr_expected_label,none,
1570+
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
15711571

15721572
// transposing
15731573
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,

include/swift/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,8 @@ class Parser {
10071007
/// Parse the @differentiating attribute.
10081008
ParserResult<DifferentiatingAttr>
10091009
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);
1010-
1010+
1011+
/// Parse the @transposing attribute.
10111012
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
10121013
SourceLoc Loc);
10131014

lib/AST/Attr.cpp

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,9 +1451,8 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
14511451

14521452

14531453
// SWIFT_ENABLE_TENSORFLOW
1454-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1455-
SourceLoc atLoc, SourceRange baseRange,
1456-
bool linear,
1454+
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
1455+
SourceRange baseRange, bool linear,
14571456
ArrayRef<ParsedAutoDiffParameter> params,
14581457
Optional<DeclNameWithLoc> jvp,
14591458
Optional<DeclNameWithLoc> vjp,
@@ -1488,8 +1487,8 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
14881487
TrailingWhereClause *clause) {
14891488
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
14901489
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
1491-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1492-
linear, parameters, std::move(jvp),
1490+
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
1491+
parameters, std::move(jvp),
14931492
std::move(vjp), clause);
14941493
}
14951494

@@ -1570,48 +1569,45 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
15701569

15711570
// SWIFT_ENABLE_TENSORFLOW
15721571
DifferentiatingAttr::DifferentiatingAttr(
1573-
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1574-
DeclNameWithLoc original, bool linear,
1575-
ArrayRef<ParsedAutoDiffParameter> params)
1572+
bool implicit, SourceLoc atLoc, SourceRange baseRange,
1573+
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
15761574
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1577-
Original(std::move(original)), Linear(linear),
1578-
NumParsedParameters(params.size()) {
1575+
Original(std::move(original)), NumParsedParameters(params.size()) {
15791576
std::copy(params.begin(), params.end(),
15801577
getTrailingObjects<ParsedAutoDiffParameter>());
15811578
}
15821579

1583-
DifferentiatingAttr::DifferentiatingAttr(
1584-
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1585-
DeclNameWithLoc original, bool linear, IndexSubset *indices)
1580+
DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1581+
SourceRange baseRange,
1582+
DeclNameWithLoc original,
1583+
IndexSubset *indices)
15861584
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1587-
Original(std::move(original)), Linear(linear), ParameterIndices(indices) {
1588-
}
1585+
Original(std::move(original)), ParameterIndices(indices) {}
15891586

15901587
DifferentiatingAttr *
1591-
DifferentiatingAttr::create(ASTContext &context, bool implicit,
1592-
SourceLoc atLoc, SourceRange baseRange,
1593-
DeclNameWithLoc original, bool linear,
1588+
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1589+
SourceRange baseRange, DeclNameWithLoc original,
15941590
ArrayRef<ParsedAutoDiffParameter> params) {
15951591
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
15961592
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
1597-
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1598-
std::move(original), linear, params);
1593+
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
1594+
std::move(original), params);
15991595
}
16001596

1601-
DifferentiatingAttr *
1602-
DifferentiatingAttr::create(ASTContext &context, bool implicit,
1603-
SourceLoc atLoc, SourceRange baseRange,
1604-
DeclNameWithLoc original, bool linear,
1605-
IndexSubset *indices) {
1597+
DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,
1598+
bool implicit, SourceLoc atLoc,
1599+
SourceRange baseRange,
1600+
DeclNameWithLoc original,
1601+
IndexSubset *indices) {
16061602
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
16071603
alignof(DifferentiatingAttr));
1608-
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1609-
std::move(original), linear, indices);
1604+
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
1605+
std::move(original), indices);
16101606
}
16111607

1612-
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1613-
SourceLoc atLoc, SourceRange baseRange,
1614-
TypeRepr *baseType, DeclNameWithLoc original,
1608+
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1609+
SourceRange baseRange, TypeRepr *baseType,
1610+
DeclNameWithLoc original,
16151611
ArrayRef<ParsedAutoDiffParameter> params)
16161612
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
16171613
BaseType(baseType), Original(std::move(original)),
@@ -1620,10 +1616,9 @@ TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
16201616
getTrailingObjects<ParsedAutoDiffParameter>());
16211617
}
16221618

1623-
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1624-
SourceLoc atLoc, SourceRange baseRange,
1625-
TypeRepr *baseType, DeclNameWithLoc original,
1626-
IndexSubset *indices)
1619+
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1620+
SourceRange baseRange, TypeRepr *baseType,
1621+
DeclNameWithLoc original, IndexSubset *indices)
16271622
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
16281623
BaseType(baseType), Original(std::move(original)),
16291624
ParameterIndices(indices) {}
@@ -1635,8 +1630,8 @@ TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
16351630
ArrayRef<ParsedAutoDiffParameter> params) {
16361631
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
16371632
void *mem = context.Allocate(size, alignof(TransposingAttr));
1638-
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1639-
baseType, std::move(original), params);
1633+
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1634+
std::move(original), params);
16401635
}
16411636

16421637
TransposingAttr *
@@ -1646,8 +1641,8 @@ TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
16461641
IndexSubset *indices) {
16471642
void *mem =
16481643
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
1649-
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1650-
baseType, std::move(original), indices);
1644+
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1645+
std::move(original), indices);
16511646
}
16521647

16531648
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

0 commit comments

Comments
 (0)