Skip to content

[AutoDiff] Serialize/Deserialize linear, make Attribute reflect whether Function is linear. #25269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,8 @@ class DifferentiableAttr final
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// Whether this function is linear (optional).
bool linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Expand All @@ -1520,13 +1522,15 @@ class DifferentiableAttr final

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
Expand All @@ -1535,13 +1539,15 @@ class DifferentiableAttr final
public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
Expand All @@ -1568,6 +1574,8 @@ class DifferentiableAttr final
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

bool isLinear() const { return linear; }

TrailingWhereClause *getWhereClause() const { return WhereClause; }

Expand Down Expand Up @@ -1608,33 +1616,37 @@ class DifferentiatingAttr final
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
/// Whether this function is linear (optional).
bool linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
AutoDiffParameterIndices *ParameterIndices = nullptr;

explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params);

explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
AutoDiffParameterIndices *indices);

public:
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params);

static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
AutoDiffParameterIndices *indices);

DeclNameWithLoc getOriginal() const { return Original; }

bool isLinear() const { return linear; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1505,7 +1505,7 @@ ERROR(attr_differentiable_expected_label,none,
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
"expected either 'linear' or 'wrt:'", ())
"expected either 'linear' or 'wrt:'", ())
Copy link
Contributor Author

@bartchr808 bartchr808 Jun 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed PR request that @rxwei caught from a previous PR (thanks!)


// differentiation `wrt` parameters clause
ERROR(expected_colon_after_label,PointsToFirstBadToken,
Expand Down
1 change: 1 addition & 0 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,7 @@ namespace decls_block {
using DifferentiableDeclAttrLayout = BCRecordLayout<
Differentiable_DECL_ATTR,
BCFixed<1>, // Implicit flag.
BCFixed<1>, // Linear flag.
IdentifierIDField, // JVP name.
DeclIDField, // JVP function declaration.
IdentifierIDField, // VJP name.
Expand Down
43 changes: 28 additions & 15 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,19 @@ static void printDifferentiableAttrArguments(
}
stream << ", ";
};

// Print if the function is marked as linear.
if (attr->isLinear()) {
isLeadingClause = false;
stream << "linear";
}

// Print differentiation parameters, if any.
auto diffParamsString = getDifferentiationParametersClauseString(
original, attr->getParameterIndices(), attr->getParsedParameters(),
prettyPrintInModule);
if (!diffParamsString.empty()) {
isLeadingClause = false;
printCommaIfNecessary();
stream << diffParamsString;
}
// Print jvp function name.
Expand Down Expand Up @@ -1319,54 +1325,59 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
// SWIFT_ENABLE_TENSORFLOW
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
NumParsedParameters(params.size()),
JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause) {
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
VJP(std::move(vjp)), WhereClause(clause) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
ArrayRef<Requirement> requirements)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) {
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
ParameterIndices(indices) {
setRequirements(context, requirements);
}

DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
parameters, std::move(jvp),
linear, parameters, std::move(jvp),
std::move(vjp), clause);
}

DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
ArrayRef<Requirement> requirements) {
void *mem = context.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
indices, std::move(jvp), std::move(vjp),
requirements);
linear, indices, std::move(jvp),
std::move(vjp), requirements);
}

void DifferentiableAttr::setRequirements(ASTContext &context,
Expand Down Expand Up @@ -1397,39 +1408,41 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
// SWIFT_ENABLE_TENSORFLOW
DifferentiatingAttr::DifferentiatingAttr(
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), NumParsedParameters(params.size()) {
Original(std::move(original)), linear(linear),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiatingAttr::DifferentiatingAttr(
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, AutoDiffParameterIndices *indices)
DeclNameWithLoc original, bool linear, AutoDiffParameterIndices *indices)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), ParameterIndices(indices) {}
Original(std::move(original)), linear(linear), ParameterIndices(indices) {}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
std::move(original), params);
std::move(original), linear, params);
}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
DeclNameWithLoc original, bool linear,
AutoDiffParameterIndices *indices) {
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
std::move(original), indices);
std::move(original), linear, indices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
Expand Down
11 changes: 5 additions & 6 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,7 @@ ParserResult<DifferentiableAttr>
Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
StringRef AttrName = "differentiable";
SourceLoc lParenLoc = loc, rParenLoc = loc;

bool linear;
bool linear = false;
SmallVector<ParsedAutoDiffParameter, 8> params;
Optional<DeclNameWithLoc> jvpSpec;
Optional<DeclNameWithLoc> vjpSpec;
Expand All @@ -851,9 +850,9 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
}

return ParserResult<DifferentiableAttr>(
DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
SourceRange(loc, rParenLoc),
params, jvpSpec, vjpSpec, whereClause));
DifferentiableAttr::create(Context, /*implicit*/ false, atLoc,
SourceRange(loc, rParenLoc), linear,
params, jvpSpec, vjpSpec, whereClause));
}

bool Parser::parseDifferentiationParametersClause(
Expand Down Expand Up @@ -1151,7 +1150,7 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
return ParserResult<DifferentiatingAttr>(
DifferentiatingAttr::create(Context, /*implicit*/ false, atLoc,
SourceRange(loc, rParenLoc),
original, params));
original, linear, params));
}

void Parser::parseObjCSelector(SmallVector<Identifier, 4> &Names,
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,8 +743,8 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
requirements = extDecl->getGenericRequirements();
auto *diffableAttr = DifferentiableAttr::create(
C, /*implicit*/ true, SourceLoc(), SourceLoc(), {}, None,
None, requirements);
C, /*implicit*/ true, SourceLoc(), SourceLoc(),
/*linear*/ false, {}, None, None, requirements);
member->getAttrs().add(diffableAttr);
// If getter does not exist, trigger synthesis and compute type.
if (!member->getGetter())
Expand Down
6 changes: 3 additions & 3 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3452,9 +3452,9 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
// the same differentiation parameters, create one.
if (!da) {
da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc,
attr->getRange(), checkedWrtParamIndices,
/*jvp*/ None, /*vjp*/ None,
derivativeRequirements);
attr->getRange(), attr->isLinear(),
checkedWrtParamIndices, /*jvp*/ None,
/*vjp*/ None, derivativeRequirements);
switch (kind) {
case AutoDiffAssociatedFunctionKind::JVP:
da->setJVPFunction(derivative);
Expand Down
7 changes: 4 additions & 3 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4127,6 +4127,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
// SWIFT_ENABLE_TENSORFLOW
case decls_block::Differentiable_DECL_ATTR: {
bool isImplicit;
bool linear;
uint64_t jvpNameId;
DeclID jvpDeclId;
uint64_t vjpNameId;
Expand All @@ -4135,8 +4136,8 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
SmallVector<Requirement, 4> requirements;

serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
scratch, isImplicit, jvpNameId, jvpDeclId, vjpNameId, vjpDeclId,
parameters);
scratch, isImplicit, linear, jvpNameId, jvpDeclId, vjpNameId,
vjpDeclId, parameters);

Optional<DeclNameWithLoc> jvp;
FuncDecl *jvpDecl = nullptr;
Expand All @@ -4160,7 +4161,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {

auto diffAttr =
DifferentiableAttr::create(ctx, isImplicit, SourceLoc(),
SourceRange(), indices, jvp, vjp,
SourceRange(), linear, indices, jvp, vjp,
requirements);
diffAttr->setJVPFunction(jvpDecl);
diffAttr->setVJPFunction(vjpDecl);
Expand Down
2 changes: 1 addition & 1 deletion lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2544,7 +2544,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {

DifferentiableDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
jvpName, jvpRef, vjpName, vjpRef, indices);
attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef, indices);

S.writeGenericRequirements(attr->getRequirements(), S.DeclTypeAbbrCodes);
return;
Expand Down
21 changes: 21 additions & 0 deletions test/Serialization/differentiable_attr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,27 @@ func simple(x: Float) -> Float {
return x
}

// CHECK: @differentiable(linear, wrt: x, jvp: jvpSimple, vjp: vjpSimple)
// CHECK-NEXT: func simple2(x: Float) -> Float
@differentiable(linear, jvp: jvpSimple, vjp: vjpSimple)
func simple2(x: Float) -> Float {
return x
}

// CHECK: @differentiable(linear, wrt: x, vjp: vjpSimple)
// CHECK-NEXT: func simple3(x: Float) -> Float
@differentiable(linear, vjp: vjpSimple)
func simple3(x: Float) -> Float {
return x
}

// CHECK: @differentiable(linear, wrt: x)
// CHECK-NEXT: func simple4(x: Float) -> Float
@differentiable(linear)
func simple4(x: Float) -> Float {
return x
}

func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
return (x, { v in v })
}
Expand Down
14 changes: 14 additions & 0 deletions test/Serialization/differentiating_attr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, F
return (x + y, { ($0, $0) })
}

// CHECK: @differentiable(linear, wrt: x, jvp: jvpLin)
// CHECK-NEXT: @differentiable(linear, wrt: (x, y), vjp: vjpLin)
func lin(x: Float, y: Float) -> Float {
return x + y
}
@differentiating(lin, linear, wrt: x)
func jvpLin(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x + y, { $0 })
}
@differentiating(lin, linear)
func vjpLin(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}

// CHECK: @differentiable(wrt: x, vjp: vjpGeneric where T : Differentiable)
func generic<T : Numeric>(x: T) -> T {
return x
Expand Down