Skip to content

Commit 32130d4

Browse files
authored
[Serialization] Implement serialization for @differentiable attribute. (#17155)
Implement (de)serialization for all components of `@differentiable` attribute except the trailing where clause (which needs to be type-checked). This is a necessary step for the `#adjoint` expression to look up `@differentiable` attributes declared on functions in other modules correctly. Addresses SR-7977.
1 parent 57df5e0 commit 32130d4

File tree

5 files changed

+164
-8
lines changed

5 files changed

+164
-8
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,8 @@ SIMPLE_DECL_ATTR(_weakLinked, WeakLinked,
317317
75)
318318

319319
// SWIFT_ENABLE_TENSORFLOW
320-
// FIXME: Make it serialized
321320
DECL_ATTR(differentiable, Differentiable,
322-
OnFunc | LongAttribute | NotSerialized,
323-
/* Not serialized */ 76)
321+
OnFunc | LongAttribute, 76)
324322

325323
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
326324
OnFunc | OnConstructor, /* Not serialized */ 77)

include/swift/Serialization/ModuleFormat.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t VERSION_MINOR = 405; // SWIFT_ENABLE_TENSORFLOW: graph_op.
58+
const uint16_t VERSION_MINOR = 406; // SWIFT_ENABLE_TENSORFLOW: serialize @differentiable.
5959

6060
using DeclIDField = BCFixed<31>;
6161

@@ -1438,8 +1438,6 @@ namespace decls_block {
14381438
= BCRecordLayout<RestatedObjCConformance_DECL_ATTR>;
14391439
using ClangImporterSynthesizedTypeDeclAttrLayout
14401440
= BCRecordLayout<ClangImporterSynthesizedType_DECL_ATTR>;
1441-
// SWIFT_ENABLE_TENSORFLOW
1442-
using DifferentiableDeclAttrLayout = BCRecordLayout<Differentiable_DECL_ATTR>;
14431441

14441442
using InlineDeclAttrLayout = BCRecordLayout<
14451443
Inline_DECL_ATTR,
@@ -1496,6 +1494,17 @@ namespace decls_block {
14961494
BCFixed<1> // specialization kind
14971495
>;
14981496

1497+
// SWIFT_ENABLE_TENSORFLOW
1498+
using DifferentiableDeclAttrLayout = BCRecordLayout<
1499+
Differentiable_DECL_ATTR,
1500+
BCFixed<1>, // Differentiation mode ('forward' or 'reverse').
1501+
IdentifierIDField, // Primal name.
1502+
DeclIDField, // Primal function declaration.
1503+
IdentifierIDField, // Adjoint name.
1504+
DeclIDField, // Adjoint function declaration.
1505+
BCArray<BCFixed<32>> // Differentiation parameters.
1506+
>;
1507+
14991508
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
15001509
using CLASS##DeclAttrLayout = BCRecordLayout< \
15011510
CLASS##_DECL_ATTR, \

lib/Serialization/Deserialization.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,6 +2528,52 @@ ModuleFile::getDeclCheckedImpl(DeclID DID, Optional<DeclContext *> ForcedContext
25282528
break;
25292529
}
25302530

2531+
// SWIFT_ENABLE_TENSORFLOW
2532+
case decls_block::Differentiable_DECL_ATTR: {
2533+
AutoDiffMode autodiffMode = AutoDiffMode::Reverse;
2534+
unsigned autodiffModeValue;
2535+
uint64_t primalNameId;
2536+
DeclID primalDeclId;
2537+
uint64_t adjointNameId;
2538+
DeclID adjointDeclId;
2539+
ArrayRef<uint64_t> paramValues;
2540+
2541+
serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
2542+
scratch, autodiffModeValue, primalNameId, primalDeclId, adjointNameId,
2543+
adjointDeclId, paramValues);
2544+
autodiffMode = autodiffModeValue
2545+
? AutoDiffMode::Reverse
2546+
: AutoDiffMode::Forward;
2547+
2548+
using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
2549+
Optional<FuncSpecifier> primal;
2550+
FuncDecl *primalDecl = nullptr;
2551+
if (primalNameId != 0 && primalDeclId != 0) {
2552+
primal = { getIdentifier(primalNameId), DeclNameLoc() };
2553+
primalDecl = cast<FuncDecl>(getDecl(primalDeclId));
2554+
}
2555+
FuncSpecifier adjoint = { getIdentifier(adjointNameId), DeclNameLoc() };
2556+
FuncDecl *adjointDecl = cast<FuncDecl>(getDecl(adjointDeclId));
2557+
2558+
SmallVector<AutoDiffParameter, 4> parameters;
2559+
SourceLoc loc;
2560+
for (auto paramValue : paramValues) {
2561+
auto parameter = paramValue & 0x01
2562+
? AutoDiffParameter::getSelfParameter(loc)
2563+
: AutoDiffParameter::getIndexParameter(loc, paramValue >> 1);
2564+
parameters.push_back(parameter);
2565+
}
2566+
// TODO: Deserialize trailing where clause.
2567+
auto diffAttr =
2568+
DifferentiableAttr::create(ctx, loc, SourceRange(), autodiffMode,
2569+
loc, parameters, primal, adjoint,
2570+
/*TrailingWhereClause*/ nullptr);
2571+
diffAttr->setPrimalFunction(primalDecl);
2572+
diffAttr->setAdjointFunction(adjointDecl);
2573+
Attr = diffAttr;
2574+
break;
2575+
}
2576+
25312577
#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
25322578
case decls_block::CLASS##_DECL_ATTR: { \
25332579
bool isImplicit; \

lib/Serialization/Serialization.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2178,8 +2178,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
21782178
case DAK_ObjCRuntimeName:
21792179
case DAK_RestatedObjCConformance:
21802180
case DAK_ClangImporterSynthesizedType:
2181-
// SWIFT_ENABLE_TENSORFLOW
2182-
case DAK_Differentiable:
21832181
llvm_unreachable("cannot serialize attribute");
21842182

21852183
case DAK_Count:
@@ -2337,6 +2335,43 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
23372335
writeGenericRequirements(SA->getRequirements(), DeclTypeAbbrCodes);
23382336
return;
23392337
}
2338+
2339+
// SWIFT_ENABLE_TENSORFLOW
2340+
case DAK_Differentiable: {
2341+
auto abbrCode = DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
2342+
auto attr = cast<DifferentiableAttr>(DA);
2343+
2344+
IdentifierID primalName = 0;
2345+
DeclID primalRef = 0;
2346+
if (auto primal = attr->getPrimal()) {
2347+
primalName = addDeclBaseNameRef(primal->Name.getBaseName());
2348+
primalRef = addDeclRef(attr->getPrimalFunction());
2349+
}
2350+
auto adjointName = addDeclBaseNameRef(attr->getAdjoint().Name.getBaseName());
2351+
auto adjointRef = addDeclRef(attr->getAdjointFunction());
2352+
2353+
SmallVector<uint32_t, 4> parameters;
2354+
for (auto param : attr->getParameters()) {
2355+
switch (param.getKind()) {
2356+
// The self parameter is uniquely identified by 0x01.
2357+
case AutoDiffParameter::Kind::Self:
2358+
parameters.push_back(1);
2359+
break;
2360+
// Index parameters are left-shifted by 1.
2361+
case AutoDiffParameter::Kind::Index:
2362+
parameters.push_back(param.getIndex() << 1);
2363+
break;
2364+
}
2365+
}
2366+
2367+
DifferentiableDeclAttrLayout::emitRecord(
2368+
Out, ScratchRecord, abbrCode, (unsigned) attr->getMode(), primalName,
2369+
primalRef, adjointName, adjointRef, parameters);
2370+
// TODO: Serialize trailing where clause.
2371+
// Type-checking where clause should be done first (mimicking the
2372+
// @_specialize attribute).
2373+
return;
2374+
}
23402375
}
23412376
}
23422377

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// TODO: Handle trailing where clause in @differentiable attribute.
3+
4+
// RUN: %empty-directory(%t)
5+
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
6+
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
7+
8+
struct CheckpointsFoo {}
9+
func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) {
10+
return (CheckpointsFoo(), x * x)
11+
}
12+
func dfoo_checkpointed(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float {
13+
return 2 * x
14+
}
15+
// CHECK-DAG: @differentiable(reverse, primal: pfoo, adjoint: dfoo_checkpointed)
16+
// CHECK-DAG: func foo_checkpointed(_ x: Float) -> Float
17+
@differentiable(reverse, primal: pfoo(_:), adjoint: dfoo_checkpointed(_:checkpoints:originalValue:seed:))
18+
func foo_checkpointed(_ x: Float) -> Float {
19+
return x * x
20+
}
21+
22+
struct S<T> {
23+
struct Checkpoints {
24+
let s: S
25+
}
26+
func primal(x: Float) -> (Checkpoints, Float) {
27+
return (Checkpoints(s: self), x)
28+
}
29+
func adjoint_checkpointed(x: Float, _: Checkpoints, _: Float, _: Float) -> S {
30+
return self
31+
}
32+
33+
// CHECK-DAG: @differentiable(reverse, (self), primal: primal, adjoint: adjoint_checkpointed)
34+
// CHECK-DAG: func original(x: Float) -> Float
35+
@differentiable(reverse, withRespectTo: (self), primal: primal, adjoint: adjoint_checkpointed)
36+
func original(x: Float) -> Float {
37+
return x
38+
}
39+
}
40+
41+
func pbaz1<T>(_ x: T, _ y: T) -> ((T, T), T) {
42+
return ((y, y), x)
43+
}
44+
func dbaz1_checkpointed<T>(_ x: T, _ y: T, primal: (T, T), originalValue: T, seed: T) -> (T, T) {
45+
return (y, x)
46+
}
47+
// CHECK-DAG: @differentiable(reverse, primal: pbaz1, adjoint: dbaz1_checkpointed)
48+
// CHECK-DAG: func baz1_checkpointed<T>(_ x: T, _ y: T) -> T
49+
@differentiable(reverse, primal: pbaz1(_:_:), adjoint: dbaz1_checkpointed(_:_:primal:originalValue:seed:))
50+
func baz1_checkpointed<T>(_ x: T, _ y: T) -> T {
51+
return x
52+
}
53+
54+
struct CheckpointsFP<T : FloatingPoint> {
55+
let meow: T
56+
}
57+
func pbaz2<T : FloatingPoint>(_ x: T, _ y: T) -> (CheckpointsFP<T>, T) {
58+
return (CheckpointsFP(meow: 1), x + y)
59+
}
60+
func dbaz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T, primal: CheckpointsFP<T>, originalValue: T, seed: T) -> (T, T) {
61+
return (1, 1)
62+
}
63+
// CHECK-DAG: @differentiable(reverse, primal: pbaz2, adjoint: dbaz2_checkpointed)
64+
// CHECK-DAG: func baz2_checkpointed<T>(_ x: T, _ y: T) -> T where T : FloatingPoint
65+
@differentiable(reverse, primal: pbaz2(_:_:), adjoint: dbaz2_checkpointed(_:_:primal:originalValue:seed:))
66+
func baz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T) -> T {
67+
return x
68+
}

0 commit comments

Comments
 (0)