Skip to content

[Serialization] Implement serialization for @differentiable attribute. #17155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,8 @@ SIMPLE_DECL_ATTR(_weakLinked, WeakLinked,
75)

// SWIFT_ENABLE_TENSORFLOW
// FIXME: Make it serialized
DECL_ATTR(differentiable, Differentiable,
OnFunc | LongAttribute | NotSerialized,
/* Not serialized */ 76)
OnFunc | LongAttribute, 76)

SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
OnFunc | OnConstructor, /* Not serialized */ 77)
Expand Down
15 changes: 12 additions & 3 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const uint16_t VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t VERSION_MINOR = 405; // SWIFT_ENABLE_TENSORFLOW: graph_op.
const uint16_t VERSION_MINOR = 406; // SWIFT_ENABLE_TENSORFLOW: serialize @differentiable.

using DeclIDField = BCFixed<31>;

Expand Down Expand Up @@ -1438,8 +1438,6 @@ namespace decls_block {
= BCRecordLayout<RestatedObjCConformance_DECL_ATTR>;
using ClangImporterSynthesizedTypeDeclAttrLayout
= BCRecordLayout<ClangImporterSynthesizedType_DECL_ATTR>;
// SWIFT_ENABLE_TENSORFLOW
using DifferentiableDeclAttrLayout = BCRecordLayout<Differentiable_DECL_ATTR>;

using InlineDeclAttrLayout = BCRecordLayout<
Inline_DECL_ATTR,
Expand Down Expand Up @@ -1496,6 +1494,17 @@ namespace decls_block {
BCFixed<1> // specialization kind
>;

// SWIFT_ENABLE_TENSORFLOW
using DifferentiableDeclAttrLayout = BCRecordLayout<
Differentiable_DECL_ATTR,
BCFixed<1>, // Differentiation mode ('forward' or 'reverse').
IdentifierIDField, // Primal name.
DeclIDField, // Primal function declaration.
IdentifierIDField, // Adjoint name.
DeclIDField, // Adjoint function declaration.
BCArray<BCFixed<32>> // Differentiation parameters.
>;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VERSION_MINOR in this file needs to be bumped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Done in d7ea8bc.

#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
using CLASS##DeclAttrLayout = BCRecordLayout< \
CLASS##_DECL_ATTR, \
Expand Down
46 changes: 46 additions & 0 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2528,6 +2528,52 @@ ModuleFile::getDeclCheckedImpl(DeclID DID, Optional<DeclContext *> ForcedContext
break;
}

// SWIFT_ENABLE_TENSORFLOW
case decls_block::Differentiable_DECL_ATTR: {
AutoDiffMode autodiffMode = AutoDiffMode::Reverse;
unsigned autodiffModeValue;
uint64_t primalNameId;
DeclID primalDeclId;
uint64_t adjointNameId;
DeclID adjointDeclId;
ArrayRef<uint64_t> paramValues;

serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
scratch, autodiffModeValue, primalNameId, primalDeclId, adjointNameId,
adjointDeclId, paramValues);
autodiffMode = autodiffModeValue
? AutoDiffMode::Reverse
: AutoDiffMode::Forward;

using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
Optional<FuncSpecifier> primal;
FuncDecl *primalDecl = nullptr;
if (primalNameId != 0 && primalDeclId != 0) {
primal = { getIdentifier(primalNameId), DeclNameLoc() };
primalDecl = cast<FuncDecl>(getDecl(primalDeclId));
}
FuncSpecifier adjoint = { getIdentifier(adjointNameId), DeclNameLoc() };
FuncDecl *adjointDecl = cast<FuncDecl>(getDecl(adjointDeclId));

SmallVector<AutoDiffParameter, 4> parameters;
SourceLoc loc;
for (auto paramValue : paramValues) {
auto parameter = paramValue & 0x01
? AutoDiffParameter::getSelfParameter(loc)
: AutoDiffParameter::getIndexParameter(loc, paramValue >> 1);
parameters.push_back(parameter);
}
// TODO: Deserialize trailing where clause.
auto diffAttr =
DifferentiableAttr::create(ctx, loc, SourceRange(), autodiffMode,
loc, parameters, primal, adjoint,
/*TrailingWhereClause*/ nullptr);
diffAttr->setPrimalFunction(primalDecl);
diffAttr->setAdjointFunction(adjointDecl);
Attr = diffAttr;
break;
}

#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
case decls_block::CLASS##_DECL_ATTR: { \
bool isImplicit; \
Expand Down
39 changes: 37 additions & 2 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2178,8 +2178,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
case DAK_ObjCRuntimeName:
case DAK_RestatedObjCConformance:
case DAK_ClangImporterSynthesizedType:
// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable:
llvm_unreachable("cannot serialize attribute");

case DAK_Count:
Expand Down Expand Up @@ -2337,6 +2335,43 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
writeGenericRequirements(SA->getRequirements(), DeclTypeAbbrCodes);
return;
}

// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable: {
auto abbrCode = DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
auto attr = cast<DifferentiableAttr>(DA);

IdentifierID primalName = 0;
DeclID primalRef = 0;
if (auto primal = attr->getPrimal()) {
primalName = addDeclBaseNameRef(primal->Name.getBaseName());
primalRef = addDeclRef(attr->getPrimalFunction());
}
auto adjointName = addDeclBaseNameRef(attr->getAdjoint().Name.getBaseName());
auto adjointRef = addDeclRef(attr->getAdjointFunction());

SmallVector<uint32_t, 4> parameters;
for (auto param : attr->getParameters()) {
switch (param.getKind()) {
// The self parameter is uniquely identified by 0x01.
case AutoDiffParameter::Kind::Self:
parameters.push_back(1);
break;
// Index parameters are left-shifted by 1.
case AutoDiffParameter::Kind::Index:
parameters.push_back(param.getIndex() << 1);
break;
}
}

DifferentiableDeclAttrLayout::emitRecord(
Out, ScratchRecord, abbrCode, (unsigned) attr->getMode(), primalName,
primalRef, adjointName, adjointRef, parameters);
// TODO: Serialize trailing where clause.
// Type-checking where clause should be done first (mimicking the
// @_specialize attribute).
return;
}
}
}

Expand Down
68 changes: 68 additions & 0 deletions test/Serialization/differentiable_attr.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SWIFT_ENABLE_TENSORFLOW
// TODO: Handle trailing where clause in @differentiable attribute.

// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s

struct CheckpointsFoo {}
func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) {
return (CheckpointsFoo(), x * x)
}
func dfoo_checkpointed(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float {
return 2 * x
}
// CHECK-DAG: @differentiable(reverse, primal: pfoo, adjoint: dfoo_checkpointed)
// CHECK-DAG: func foo_checkpointed(_ x: Float) -> Float
@differentiable(reverse, primal: pfoo(_:), adjoint: dfoo_checkpointed(_:checkpoints:originalValue:seed:))
func foo_checkpointed(_ x: Float) -> Float {
return x * x
}

struct S<T> {
struct Checkpoints {
let s: S
}
func primal(x: Float) -> (Checkpoints, Float) {
return (Checkpoints(s: self), x)
}
func adjoint_checkpointed(x: Float, _: Checkpoints, _: Float, _: Float) -> S {
return self
}

// CHECK-DAG: @differentiable(reverse, (self), primal: primal, adjoint: adjoint_checkpointed)
// CHECK-DAG: func original(x: Float) -> Float
@differentiable(reverse, withRespectTo: (self), primal: primal, adjoint: adjoint_checkpointed)
func original(x: Float) -> Float {
return x
}
}

func pbaz1<T>(_ x: T, _ y: T) -> ((T, T), T) {
return ((y, y), x)
}
func dbaz1_checkpointed<T>(_ x: T, _ y: T, primal: (T, T), originalValue: T, seed: T) -> (T, T) {
return (y, x)
}
// CHECK-DAG: @differentiable(reverse, primal: pbaz1, adjoint: dbaz1_checkpointed)
// CHECK-DAG: func baz1_checkpointed<T>(_ x: T, _ y: T) -> T
@differentiable(reverse, primal: pbaz1(_:_:), adjoint: dbaz1_checkpointed(_:_:primal:originalValue:seed:))
func baz1_checkpointed<T>(_ x: T, _ y: T) -> T {
return x
}

struct CheckpointsFP<T : FloatingPoint> {
let meow: T
}
func pbaz2<T : FloatingPoint>(_ x: T, _ y: T) -> (CheckpointsFP<T>, T) {
return (CheckpointsFP(meow: 1), x + y)
}
func dbaz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T, primal: CheckpointsFP<T>, originalValue: T, seed: T) -> (T, T) {
return (1, 1)
}
// CHECK-DAG: @differentiable(reverse, primal: pbaz2, adjoint: dbaz2_checkpointed)
// CHECK-DAG: func baz2_checkpointed<T>(_ x: T, _ y: T) -> T where T : FloatingPoint
@differentiable(reverse, primal: pbaz2(_:_:), adjoint: dbaz2_checkpointed(_:_:primal:originalValue:seed:))
func baz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T) -> T {
return x
}