Skip to content

Commit 8d9df6c

Browse files
authored
[AutoDiff] Serialize/Deserialize linear, make Attribute reflect whether Function is linear. (#25269)
- Makes the 'DifferentiableAttr' and 'DifferentiatingAttr' object reflect whether the function is marked as linear or not, instead of defaulting to false like I did temporarily in previous PRs. - Serializes/Deserializes the linear argument.
1 parent 2d9111d commit 8d9df6c

File tree

11 files changed

+96
-35
lines changed

11 files changed

+96
-35
lines changed

include/swift/AST/Attr.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,6 +1496,8 @@ class DifferentiableAttr final
14961496
ParsedAutoDiffParameter> {
14971497
friend TrailingObjects;
14981498

1499+
/// Whether this function is linear (optional).
1500+
bool linear;
14991501
/// The number of parsed parameters specified in 'wrt:'.
15001502
unsigned NumParsedParameters = 0;
15011503
/// The JVP function.
@@ -1520,13 +1522,15 @@ class DifferentiableAttr final
15201522

15211523
explicit DifferentiableAttr(ASTContext &context, bool implicit,
15221524
SourceLoc atLoc, SourceRange baseRange,
1525+
bool linear,
15231526
ArrayRef<ParsedAutoDiffParameter> parameters,
15241527
Optional<DeclNameWithLoc> jvp,
15251528
Optional<DeclNameWithLoc> vjp,
15261529
TrailingWhereClause *clause);
15271530

15281531
explicit DifferentiableAttr(ASTContext &context, bool implicit,
15291532
SourceLoc atLoc, SourceRange baseRange,
1533+
bool linear,
15301534
AutoDiffParameterIndices *indices,
15311535
Optional<DeclNameWithLoc> jvp,
15321536
Optional<DeclNameWithLoc> vjp,
@@ -1535,13 +1539,15 @@ class DifferentiableAttr final
15351539
public:
15361540
static DifferentiableAttr *create(ASTContext &context, bool implicit,
15371541
SourceLoc atLoc, SourceRange baseRange,
1542+
bool linear,
15381543
ArrayRef<ParsedAutoDiffParameter> params,
15391544
Optional<DeclNameWithLoc> jvp,
15401545
Optional<DeclNameWithLoc> vjp,
15411546
TrailingWhereClause *clause);
15421547

15431548
static DifferentiableAttr *create(ASTContext &context, bool implicit,
15441549
SourceLoc atLoc, SourceRange baseRange,
1550+
bool linear,
15451551
AutoDiffParameterIndices *indices,
15461552
Optional<DeclNameWithLoc> jvp,
15471553
Optional<DeclNameWithLoc> vjp,
@@ -1568,6 +1574,8 @@ class DifferentiableAttr final
15681574
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
15691575
return NumParsedParameters;
15701576
}
1577+
1578+
bool isLinear() const { return linear; }
15711579

15721580
TrailingWhereClause *getWhereClause() const { return WhereClause; }
15731581

@@ -1608,33 +1616,37 @@ class DifferentiatingAttr final
16081616
DeclNameWithLoc Original;
16091617
/// The original function, resolved by the type checker.
16101618
FuncDecl *OriginalFunction = nullptr;
1619+
/// Whether this function is linear (optional).
1620+
bool linear;
16111621
/// The number of parsed parameters specified in 'wrt:'.
16121622
unsigned NumParsedParameters = 0;
16131623
/// The differentiation parameters' indices, resolved by the type checker.
16141624
AutoDiffParameterIndices *ParameterIndices = nullptr;
16151625

16161626
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16171627
SourceLoc atLoc, SourceRange baseRange,
1618-
DeclNameWithLoc original,
1628+
DeclNameWithLoc original, bool linear,
16191629
ArrayRef<ParsedAutoDiffParameter> params);
16201630

16211631
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16221632
SourceLoc atLoc, SourceRange baseRange,
1623-
DeclNameWithLoc original,
1633+
DeclNameWithLoc original, bool linear,
16241634
AutoDiffParameterIndices *indices);
16251635

16261636
public:
16271637
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
16281638
SourceLoc atLoc, SourceRange baseRange,
1629-
DeclNameWithLoc original,
1639+
DeclNameWithLoc original, bool linear,
16301640
ArrayRef<ParsedAutoDiffParameter> params);
16311641

16321642
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
16331643
SourceLoc atLoc, SourceRange baseRange,
1634-
DeclNameWithLoc original,
1644+
DeclNameWithLoc original, bool linear,
16351645
AutoDiffParameterIndices *indices);
16361646

16371647
DeclNameWithLoc getOriginal() const { return Original; }
1648+
1649+
bool isLinear() const { return linear; }
16381650

16391651
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
16401652
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }

include/swift/AST/DiagnosticsParse.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,7 @@ ERROR(attr_differentiable_expected_label,none,
15051505
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15061506
"expected an original function name", ())
15071507
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
1508-
"expected either 'linear' or 'wrt:'", ())
1508+
"expected either 'linear' or 'wrt:'", ())
15091509

15101510
// differentiation `wrt` parameters clause
15111511
ERROR(expected_colon_after_label,PointsToFirstBadToken,

include/swift/Serialization/ModuleFormat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,7 @@ namespace decls_block {
16531653
using DifferentiableDeclAttrLayout = BCRecordLayout<
16541654
Differentiable_DECL_ATTR,
16551655
BCFixed<1>, // Implicit flag.
1656+
BCFixed<1>, // Linear flag.
16561657
IdentifierIDField, // JVP name.
16571658
DeclIDField, // JVP function declaration.
16581659
IdentifierIDField, // VJP name.

lib/AST/Attr.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,19 @@ static void printDifferentiableAttrArguments(
423423
}
424424
stream << ", ";
425425
};
426+
427+
// Print if the function is marked as linear.
428+
if (attr->isLinear()) {
429+
isLeadingClause = false;
430+
stream << "linear";
431+
}
426432

427433
// Print differentiation parameters, if any.
428434
auto diffParamsString = getDifferentiationParametersClauseString(
429435
original, attr->getParameterIndices(), attr->getParsedParameters(),
430436
prettyPrintInModule);
431437
if (!diffParamsString.empty()) {
432-
isLeadingClause = false;
438+
printCommaIfNecessary();
433439
stream << diffParamsString;
434440
}
435441
// Print jvp function name.
@@ -1319,54 +1325,59 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
13191325
// SWIFT_ENABLE_TENSORFLOW
13201326
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
13211327
SourceLoc atLoc, SourceRange baseRange,
1328+
bool linear,
13221329
ArrayRef<ParsedAutoDiffParameter> params,
13231330
Optional<DeclNameWithLoc> jvp,
13241331
Optional<DeclNameWithLoc> vjp,
13251332
TrailingWhereClause *clause)
13261333
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1327-
NumParsedParameters(params.size()),
1328-
JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause) {
1334+
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
1335+
VJP(std::move(vjp)), WhereClause(clause) {
13291336
std::copy(params.begin(), params.end(),
13301337
getTrailingObjects<ParsedAutoDiffParameter>());
13311338
}
13321339

13331340
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
13341341
SourceLoc atLoc, SourceRange baseRange,
1342+
bool linear,
13351343
AutoDiffParameterIndices *indices,
13361344
Optional<DeclNameWithLoc> jvp,
13371345
Optional<DeclNameWithLoc> vjp,
13381346
ArrayRef<Requirement> requirements)
13391347
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1340-
JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) {
1348+
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1349+
ParameterIndices(indices) {
13411350
setRequirements(context, requirements);
13421351
}
13431352

13441353
DifferentiableAttr *
13451354
DifferentiableAttr::create(ASTContext &context, bool implicit,
13461355
SourceLoc atLoc, SourceRange baseRange,
1356+
bool linear,
13471357
ArrayRef<ParsedAutoDiffParameter> parameters,
13481358
Optional<DeclNameWithLoc> jvp,
13491359
Optional<DeclNameWithLoc> vjp,
13501360
TrailingWhereClause *clause) {
13511361
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
13521362
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
13531363
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1354-
parameters, std::move(jvp),
1364+
linear, parameters, std::move(jvp),
13551365
std::move(vjp), clause);
13561366
}
13571367

13581368
DifferentiableAttr *
13591369
DifferentiableAttr::create(ASTContext &context, bool implicit,
13601370
SourceLoc atLoc, SourceRange baseRange,
1371+
bool linear,
13611372
AutoDiffParameterIndices *indices,
13621373
Optional<DeclNameWithLoc> jvp,
13631374
Optional<DeclNameWithLoc> vjp,
13641375
ArrayRef<Requirement> requirements) {
13651376
void *mem = context.Allocate(sizeof(DifferentiableAttr),
13661377
alignof(DifferentiableAttr));
13671378
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1368-
indices, std::move(jvp), std::move(vjp),
1369-
requirements);
1379+
linear, indices, std::move(jvp),
1380+
std::move(vjp), requirements);
13701381
}
13711382

13721383
void DifferentiableAttr::setRequirements(ASTContext &context,
@@ -1397,39 +1408,41 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
13971408
// SWIFT_ENABLE_TENSORFLOW
13981409
DifferentiatingAttr::DifferentiatingAttr(
13991410
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1400-
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
1411+
DeclNameWithLoc original, bool linear,
1412+
ArrayRef<ParsedAutoDiffParameter> params)
14011413
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1402-
Original(std::move(original)), NumParsedParameters(params.size()) {
1414+
Original(std::move(original)), linear(linear),
1415+
NumParsedParameters(params.size()) {
14031416
std::copy(params.begin(), params.end(),
14041417
getTrailingObjects<ParsedAutoDiffParameter>());
14051418
}
14061419

14071420
DifferentiatingAttr::DifferentiatingAttr(
14081421
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1409-
DeclNameWithLoc original, AutoDiffParameterIndices *indices)
1422+
DeclNameWithLoc original, bool linear, AutoDiffParameterIndices *indices)
14101423
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1411-
Original(std::move(original)), ParameterIndices(indices) {}
1424+
Original(std::move(original)), linear(linear), ParameterIndices(indices) {}
14121425

14131426
DifferentiatingAttr *
14141427
DifferentiatingAttr::create(ASTContext &context, bool implicit,
14151428
SourceLoc atLoc, SourceRange baseRange,
1416-
DeclNameWithLoc original,
1429+
DeclNameWithLoc original, bool linear,
14171430
ArrayRef<ParsedAutoDiffParameter> params) {
14181431
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
14191432
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
14201433
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1421-
std::move(original), params);
1434+
std::move(original), linear, params);
14221435
}
14231436

14241437
DifferentiatingAttr *
14251438
DifferentiatingAttr::create(ASTContext &context, bool implicit,
14261439
SourceLoc atLoc, SourceRange baseRange,
1427-
DeclNameWithLoc original,
1440+
DeclNameWithLoc original, bool linear,
14281441
AutoDiffParameterIndices *indices) {
14291442
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
14301443
alignof(DifferentiatingAttr));
14311444
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1432-
std::move(original), indices);
1445+
std::move(original), linear, indices);
14331446
}
14341447

14351448
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/Parse/ParseDecl.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,7 @@ ParserResult<DifferentiableAttr>
829829
Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
830830
StringRef AttrName = "differentiable";
831831
SourceLoc lParenLoc = loc, rParenLoc = loc;
832-
833-
bool linear;
832+
bool linear = false;
834833
SmallVector<ParsedAutoDiffParameter, 8> params;
835834
Optional<DeclNameWithLoc> jvpSpec;
836835
Optional<DeclNameWithLoc> vjpSpec;
@@ -851,9 +850,9 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
851850
}
852851

853852
return ParserResult<DifferentiableAttr>(
854-
DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
855-
SourceRange(loc, rParenLoc),
856-
params, jvpSpec, vjpSpec, whereClause));
853+
DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
854+
SourceRange(loc, rParenLoc), linear,
855+
params, jvpSpec, vjpSpec, whereClause));
857856
}
858857

859858
bool Parser::parseDifferentiationParametersClause(
@@ -1151,7 +1150,7 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
11511150
return ParserResult<DifferentiatingAttr>(
11521151
DifferentiatingAttr::create(Context, /*implicit*/ false, atLoc,
11531152
SourceRange(loc, rParenLoc),
1154-
original, params));
1153+
original, linear, params));
11551154
}
11561155

11571156
void Parser::parseObjCSelector(SmallVector<Identifier, 4> &Names,

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,8 +743,8 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
743743
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
744744
requirements = extDecl->getGenericRequirements();
745745
auto *diffableAttr = DifferentiableAttr::create(
746-
C, /*implicit*/ true, SourceLoc(), SourceLoc(), {}, None,
747-
None, requirements);
746+
C, /*implicit*/ true, SourceLoc(), SourceLoc(),
747+
/*linear*/ false, {}, None, None, requirements);
748748
member->getAttrs().add(diffableAttr);
749749
// If getter does not exist, trigger synthesis and compute type.
750750
if (!member->getGetter())

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3452,9 +3452,9 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
34523452
// the same differentiation parameters, create one.
34533453
if (!da) {
34543454
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
3455-
attr->getRange(), checkedWrtParamIndices,
3456-
/*jvp*/ None, /*vjp*/ None,
3457-
derivativeRequirements);
3455+
attr->getRange(), attr->isLinear(),
3456+
checkedWrtParamIndices, /*jvp*/ None,
3457+
/*vjp*/ None, derivativeRequirements);
34583458
switch (kind) {
34593459
case AutoDiffAssociatedFunctionKind::JVP:
34603460
da->setJVPFunction(derivative);

lib/Serialization/Deserialization.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4127,6 +4127,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41274127
// SWIFT_ENABLE_TENSORFLOW
41284128
case decls_block::Differentiable_DECL_ATTR: {
41294129
bool isImplicit;
4130+
bool linear;
41304131
uint64_t jvpNameId;
41314132
DeclID jvpDeclId;
41324133
uint64_t vjpNameId;
@@ -4135,8 +4136,8 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41354136
SmallVector<Requirement, 4> requirements;
41364137

41374138
serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
4138-
scratch, isImplicit, jvpNameId, jvpDeclId, vjpNameId, vjpDeclId,
4139-
parameters);
4139+
scratch, isImplicit, linear, jvpNameId, jvpDeclId, vjpNameId,
4140+
vjpDeclId, parameters);
41404141

41414142
Optional<DeclNameWithLoc> jvp;
41424143
FuncDecl *jvpDecl = nullptr;
@@ -4160,7 +4161,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41604161

41614162
auto diffAttr =
41624163
DifferentiableAttr::create(ctx, isImplicit, SourceLoc(),
4163-
SourceRange(), indices, jvp, vjp,
4164+
SourceRange(), linear, indices, jvp, vjp,
41644165
requirements);
41654166
diffAttr->setJVPFunction(jvpDecl);
41664167
diffAttr->setVJPFunction(vjpDecl);

lib/Serialization/Serialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2544,7 +2544,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
25442544

25452545
DifferentiableDeclAttrLayout::emitRecord(
25462546
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
2547-
jvpName, jvpRef, vjpName, vjpRef, indices);
2547+
attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, indices);
25482548

25492549
S.writeGenericRequirements(attr->getRequirements(), S.DeclTypeAbbrCodes);
25502550
return;

test/Serialization/differentiable_attr.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@ func simple(x: Float) -> Float {
1414
return x
1515
}
1616

17+
// CHECK: @differentiable(linear, wrt: x, jvp: jvpSimple, vjp: vjpSimple)
18+
// CHECK-NEXT: func simple2(x: Float) -> Float
19+
@differentiable(linear, jvp: jvpSimple, vjp: vjpSimple)
20+
func simple2(x: Float) -> Float {
21+
return x
22+
}
23+
24+
// CHECK: @differentiable(linear, wrt: x, vjp: vjpSimple)
25+
// CHECK-NEXT: func simple3(x: Float) -> Float
26+
@differentiable(linear, vjp: vjpSimple)
27+
func simple3(x: Float) -> Float {
28+
return x
29+
}
30+
31+
// CHECK: @differentiable(linear, wrt: x)
32+
// CHECK-NEXT: func simple4(x: Float) -> Float
33+
@differentiable(linear)
34+
func simple4(x: Float) -> Float {
35+
return x
36+
}
37+
1738
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
1839
return (x, { v in v })
1940
}

test/Serialization/differentiating_attr.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, F
2121
return (x + y, { ($0, $0) })
2222
}
2323

24+
// CHECK: @differentiable(linear, wrt: x, jvp: jvpLin)
25+
// CHECK-NEXT: @differentiable(linear, wrt: (x, y), vjp: vjpLin)
26+
func lin(x: Float, y: Float) -> Float {
27+
return x + y
28+
}
29+
@differentiating(lin, linear, wrt: x)
30+
func jvpLin(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
31+
return (x + y, { $0 })
32+
}
33+
@differentiating(lin, linear)
34+
func vjpLin(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
35+
return (x + y, { ($0, $0) })
36+
}
37+
2438
// CHECK: @differentiable(wrt: x, vjp: vjpGeneric where T : Differentiable)
2539
func generic<T : Numeric>(x: T) -> T {
2640
return x

0 commit comments

Comments
 (0)