Skip to content

Commit e688135

Browse files
committed
PR feedback and differentiating attr.
1 parent b398a76 commit e688135

File tree

8 files changed

+52
-21
lines changed

8 files changed

+52
-21
lines changed

include/swift/AST/Attr.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,33 +1616,37 @@ class DifferentiatingAttr final
16161616
DeclNameWithLoc Original;
16171617
/// The original function, resolved by the type checker.
16181618
FuncDecl *OriginalFunction = nullptr;
1619+
/// Whether this function is linear (optional).
1620+
bool linear;
16191621
/// The number of parsed parameters specified in 'wrt:'.
16201622
unsigned NumParsedParameters = 0;
16211623
/// The differentiation parameters' indices, resolved by the type checker.
16221624
AutoDiffParameterIndices *ParameterIndices = nullptr;
16231625

16241626
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16251627
SourceLoc atLoc, SourceRange baseRange,
1626-
DeclNameWithLoc original,
1628+
DeclNameWithLoc original, bool linear,
16271629
ArrayRef<ParsedAutoDiffParameter> params);
16281630

16291631
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16301632
SourceLoc atLoc, SourceRange baseRange,
1631-
DeclNameWithLoc original,
1633+
DeclNameWithLoc original, bool linear,
16321634
AutoDiffParameterIndices *indices);
16331635

16341636
public:
16351637
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
16361638
SourceLoc atLoc, SourceRange baseRange,
1637-
DeclNameWithLoc original,
1639+
DeclNameWithLoc original, bool linear,
16381640
ArrayRef<ParsedAutoDiffParameter> params);
16391641

16401642
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
16411643
SourceLoc atLoc, SourceRange baseRange,
1642-
DeclNameWithLoc original,
1644+
DeclNameWithLoc original, bool linear,
16431645
AutoDiffParameterIndices *indices);
16441646

16451647
DeclNameWithLoc getOriginal() const { return Original; }
1648+
1649+
bool isLinear() const { return linear; }
16461650

16471651
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
16481652
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }

include/swift/AST/DiagnosticsParse.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,7 @@ ERROR(attr_differentiable_expected_label,none,
15051505
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15061506
"expected an original function name", ())
15071507
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
1508-
"expected either 'linear' or 'wrt:'", ())
1508+
"expected either 'linear' or 'wrt:'", ())
15091509

15101510
// differentiation `wrt` parameters clause
15111511
ERROR(expected_colon_after_label,PointsToFirstBadToken,

lib/AST/Attr.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,39 +1408,39 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
14081408
// SWIFT_ENABLE_TENSORFLOW
14091409
DifferentiatingAttr::DifferentiatingAttr(
14101410
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1411-
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
1411+
DeclNameWithLoc original, bool linear, ArrayRef<ParsedAutoDiffParameter> params)
14121412
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1413-
Original(std::move(original)), NumParsedParameters(params.size()) {
1413+
Original(std::move(original)), linear(linear), NumParsedParameters(params.size()) {
14141414
std::copy(params.begin(), params.end(),
14151415
getTrailingObjects<ParsedAutoDiffParameter>());
14161416
}
14171417

14181418
DifferentiatingAttr::DifferentiatingAttr(
14191419
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
1420-
DeclNameWithLoc original, AutoDiffParameterIndices *indices)
1420+
DeclNameWithLoc original, bool linear, AutoDiffParameterIndices *indices)
14211421
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1422-
Original(std::move(original)), ParameterIndices(indices) {}
1422+
Original(std::move(original)), linear(linear), ParameterIndices(indices) {}
14231423

14241424
DifferentiatingAttr *
14251425
DifferentiatingAttr::create(ASTContext &context, bool implicit,
14261426
SourceLoc atLoc, SourceRange baseRange,
1427-
DeclNameWithLoc original,
1427+
DeclNameWithLoc original, bool linear,
14281428
ArrayRef<ParsedAutoDiffParameter> params) {
14291429
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
14301430
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
14311431
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1432-
std::move(original), params);
1432+
std::move(original), linear, params);
14331433
}
14341434

14351435
DifferentiatingAttr *
14361436
DifferentiatingAttr::create(ASTContext &context, bool implicit,
14371437
SourceLoc atLoc, SourceRange baseRange,
1438-
DeclNameWithLoc original,
1438+
DeclNameWithLoc original, bool linear,
14391439
AutoDiffParameterIndices *indices) {
14401440
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
14411441
alignof(DifferentiatingAttr));
14421442
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
1443-
std::move(original), indices);
1443+
std::move(original), linear, indices);
14441444
}
14451445

14461446
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/Parse/ParseDecl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,8 +851,8 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
851851

852852
return ParserResult<DifferentiableAttr>(
853853
DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
854-
SourceRange(loc, rParenLoc), linear,
855-
params, jvpSpec, vjpSpec, whereClause));
854+
SourceRange(loc, rParenLoc), linear,
855+
params, jvpSpec, vjpSpec, whereClause));
856856
}
857857

858858
bool Parser::parseDifferentiationParametersClause(
@@ -1150,7 +1150,7 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
11501150
return ParserResult<DifferentiatingAttr>(
11511151
DifferentiatingAttr::create(Context, /*implicit*/ false, atLoc,
11521152
SourceRange(loc, rParenLoc),
1153-
original, params));
1153+
original, linear, params));
11541154
}
11551155

11561156
void Parser::parseObjCSelector(SmallVector<Identifier, 4> &Names,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3452,7 +3452,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
34523452
// the same differentiation parameters, create one.
34533453
if (!da) {
34543454
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
3455-
attr->getRange(), /*TODO linear*/ false,
3455+
attr->getRange(), attr->isLinear(),
34563456
checkedWrtParamIndices, /*jvp*/ None,
34573457
/*vjp*/ None, derivativeRequirements);
34583458
switch (kind) {

lib/Serialization/Serialization.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2522,7 +2522,6 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
25222522
case DAK_Differentiable: {
25232523
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
25242524
auto attr = cast<DifferentiableAttr>(DA);
2525-
bool linear = false;
25262525

25272526
IdentifierID jvpName = 0;
25282527
DeclID jvpRef = 0;
@@ -2545,7 +2544,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
25452544

25462545
DifferentiableDeclAttrLayout::emitRecord(
25472546
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
2548-
linear, jvpName, jvpRef, vjpName, vjpRef, indices);
2547+
attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, indices);
25492548

25502549
S.writeGenericRequirements(attr->getRequirements(), S.DeclTypeAbbrCodes);
25512550
return;

test/Serialization/differentiable_attr.swift

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,27 @@ func simple(x: Float) -> Float {
1414
return x
1515
}
1616

17-
// CHECK: @differentiable(wrt: x, jvp: jvpSimple, vjp: vjpSimple)
17+
// CHECK: @differentiable(linear, wrt: x, jvp: jvpSimple, vjp: vjpSimple)
1818
// CHECK-NEXT: func simple2(x: Float) -> Float
19-
@differentiable(linear, wrt: x, jvp: jvpSimple, vjp: vjpSimple)
19+
@differentiable(linear, jvp: jvpSimple, vjp: vjpSimple)
2020
func simple2(x: Float) -> Float {
2121
return x
2222
}
2323

24+
// CHECK: @differentiable(linear, wrt: x, vjp: vjpSimple)
25+
// CHECK-NEXT: func simple3(x: Float) -> Float
26+
@differentiable(linear, vjp: vjpSimple)
27+
func simple3(x: Float) -> Float {
28+
return x
29+
}
30+
31+
// CHECK: @differentiable(linear, wrt: x)
32+
// CHECK-NEXT: func simple4(x: Float) -> Float
33+
@differentiable(linear)
34+
func simple4(x: Float) -> Float {
35+
return x
36+
}
37+
2438
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
2539
return (x, { v in v })
2640
}

test/Serialization/differentiating_attr.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, F
2121
return (x + y, { ($0, $0) })
2222
}
2323

24+
// CHECK: @differentiable(linear, wrt: x, jvp: jvpLin)
25+
// CHECK-NEXT: @differentiable(linear, wrt: (x, y), vjp: vjpLin)
26+
func lin(x: Float, y: Float) -> Float {
27+
return x + y
28+
}
29+
@differentiating(lin, linear, wrt: x)
30+
func jvpLin(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
31+
return (x + y, { $0 })
32+
}
33+
@differentiating(lin, linear)
34+
func vjpLin(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
35+
return (x + y, { ($0, $0) })
36+
}
37+
2438
// CHECK: @differentiable(wrt: x, vjp: vjpGeneric where T : Differentiable)
2539
func generic<T : Numeric>(x: T) -> T {
2640
return x

0 commit comments

Comments
 (0)