Skip to content

Commit d2bf0dc

Browse files
committed
[Serialization] Implement serialization for @differentiable attribute. (#17155)
Implement (de)serialization for all components of `@differentiable` attribute except the trailing where clause (which needs to be type-checked). This is a necessary step for the `#adjoint` expression to look up `@differentiable` attributes declared on functions in other modules correctly. Addresses SR-7977.
1 parent 3f68d50 commit d2bf0dc

File tree

5 files changed

+163
-7
lines changed

5 files changed

+163
-7
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,8 @@ SIMPLE_DECL_ATTR(_forbidSerializingReference, ForbidSerializingReference,
381381
77)
382382

383383
// SWIFT_ENABLE_TENSORFLOW
384-
// FIXME: Make it serialized
385384
DECL_ATTR(differentiable, Differentiable,
386-
OnFunc | LongAttribute | NotSerialized,
387-
/* Not serialized */ 78)
385+
OnFunc | LongAttribute, 78)
388386

389387
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
390388
OnFunc | OnConstructor, /* Not serialized */ 79)

include/swift/Serialization/ModuleFormat.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,8 +1447,6 @@ namespace decls_block {
14471447
= BCRecordLayout<RestatedObjCConformance_DECL_ATTR>;
14481448
using ClangImporterSynthesizedTypeDeclAttrLayout
14491449
= BCRecordLayout<ClangImporterSynthesizedType_DECL_ATTR>;
1450-
// SWIFT_ENABLE_TENSORFLOW
1451-
using DifferentiableDeclAttrLayout = BCRecordLayout<Differentiable_DECL_ATTR>;
14521450

14531451
using InlineDeclAttrLayout = BCRecordLayout<
14541452
Inline_DECL_ATTR,
@@ -1505,6 +1503,17 @@ namespace decls_block {
15051503
BCFixed<1> // specialization kind
15061504
>;
15071505

1506+
// SWIFT_ENABLE_TENSORFLOW
1507+
using DifferentiableDeclAttrLayout = BCRecordLayout<
1508+
Differentiable_DECL_ATTR,
1509+
BCFixed<1>, // Differentiation mode ('forward' or 'reverse').
1510+
IdentifierIDField, // Primal name.
1511+
DeclIDField, // Primal function declaration.
1512+
IdentifierIDField, // Adjoint name.
1513+
DeclIDField, // Adjoint function declaration.
1514+
BCArray<BCFixed<32>> // Differentiation parameters.
1515+
>;
1516+
15081517
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
15091518
using CLASS##DeclAttrLayout = BCRecordLayout< \
15101519
CLASS##_DECL_ATTR, \

lib/Serialization/Deserialization.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2506,6 +2506,52 @@ ModuleFile::getDeclCheckedImpl(DeclID DID, Optional<DeclContext *> ForcedContext
25062506
break;
25072507
}
25082508

2509+
// SWIFT_ENABLE_TENSORFLOW
2510+
case decls_block::Differentiable_DECL_ATTR: {
2511+
AutoDiffMode autodiffMode = AutoDiffMode::Reverse;
2512+
unsigned autodiffModeValue;
2513+
uint64_t primalNameId;
2514+
DeclID primalDeclId;
2515+
uint64_t adjointNameId;
2516+
DeclID adjointDeclId;
2517+
ArrayRef<uint64_t> paramValues;
2518+
2519+
serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
2520+
scratch, autodiffModeValue, primalNameId, primalDeclId, adjointNameId,
2521+
adjointDeclId, paramValues);
2522+
autodiffMode = autodiffModeValue
2523+
? AutoDiffMode::Reverse
2524+
: AutoDiffMode::Forward;
2525+
2526+
using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
2527+
Optional<FuncSpecifier> primal;
2528+
FuncDecl *primalDecl = nullptr;
2529+
if (primalNameId != 0 && primalDeclId != 0) {
2530+
primal = { getIdentifier(primalNameId), DeclNameLoc() };
2531+
primalDecl = cast<FuncDecl>(getDecl(primalDeclId));
2532+
}
2533+
FuncSpecifier adjoint = { getIdentifier(adjointNameId), DeclNameLoc() };
2534+
FuncDecl *adjointDecl = cast<FuncDecl>(getDecl(adjointDeclId));
2535+
2536+
SmallVector<AutoDiffParameter, 4> parameters;
2537+
SourceLoc loc;
2538+
for (auto paramValue : paramValues) {
2539+
auto parameter = paramValue & 0x01
2540+
? AutoDiffParameter::getSelfParameter(loc)
2541+
: AutoDiffParameter::getIndexParameter(loc, paramValue >> 1);
2542+
parameters.push_back(parameter);
2543+
}
2544+
// TODO: Deserialize trailing where clause.
2545+
auto diffAttr =
2546+
DifferentiableAttr::create(ctx, loc, SourceRange(), autodiffMode,
2547+
loc, parameters, primal, adjoint,
2548+
/*TrailingWhereClause*/ nullptr);
2549+
diffAttr->setPrimalFunction(primalDecl);
2550+
diffAttr->setAdjointFunction(adjointDecl);
2551+
Attr = diffAttr;
2552+
break;
2553+
}
2554+
25092555
#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
25102556
case decls_block::CLASS##_DECL_ATTR: { \
25112557
bool isImplicit; \

lib/Serialization/Serialization.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,8 +2174,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
21742174
case DAK_ObjCRuntimeName:
21752175
case DAK_RestatedObjCConformance:
21762176
case DAK_ClangImporterSynthesizedType:
2177-
// SWIFT_ENABLE_TENSORFLOW
2178-
case DAK_Differentiable:
21792177
llvm_unreachable("cannot serialize attribute");
21802178

21812179
case DAK_Count:
@@ -2333,6 +2331,43 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
23332331
writeGenericRequirements(SA->getRequirements(), DeclTypeAbbrCodes);
23342332
return;
23352333
}
2334+
2335+
// SWIFT_ENABLE_TENSORFLOW
2336+
case DAK_Differentiable: {
2337+
auto abbrCode = DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
2338+
auto attr = cast<DifferentiableAttr>(DA);
2339+
2340+
IdentifierID primalName = 0;
2341+
DeclID primalRef = 0;
2342+
if (auto primal = attr->getPrimal()) {
2343+
primalName = addDeclBaseNameRef(primal->Name.getBaseName());
2344+
primalRef = addDeclRef(attr->getPrimalFunction());
2345+
}
2346+
auto adjointName = addDeclBaseNameRef(attr->getAdjoint().Name.getBaseName());
2347+
auto adjointRef = addDeclRef(attr->getAdjointFunction());
2348+
2349+
SmallVector<uint32_t, 4> parameters;
2350+
for (auto param : attr->getParameters()) {
2351+
switch (param.getKind()) {
2352+
// The self parameter is uniquely identified by 0x01.
2353+
case AutoDiffParameter::Kind::Self:
2354+
parameters.push_back(1);
2355+
break;
2356+
// Index parameters are left-shifted by 1.
2357+
case AutoDiffParameter::Kind::Index:
2358+
parameters.push_back(param.getIndex() << 1);
2359+
break;
2360+
}
2361+
}
2362+
2363+
DifferentiableDeclAttrLayout::emitRecord(
2364+
Out, ScratchRecord, abbrCode, (unsigned) attr->getMode(), primalName,
2365+
primalRef, adjointName, adjointRef, parameters);
2366+
// TODO: Serialize trailing where clause.
2367+
// Type-checking where clause should be done first (mimicking the
2368+
// @_specialize attribute).
2369+
return;
2370+
}
23362371
}
23372372
}
23382373

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// TODO: Handle trailing where clause in @differentiable attribute.
3+
4+
// RUN: %empty-directory(%t)
5+
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
6+
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
7+
8+
struct CheckpointsFoo {}
9+
func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) {
10+
return (CheckpointsFoo(), x * x)
11+
}
12+
func dfoo_checkpointed(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float {
13+
return 2 * x
14+
}
15+
// CHECK-DAG: @differentiable(reverse, primal: pfoo, adjoint: dfoo_checkpointed)
16+
// CHECK-DAG: func foo_checkpointed(_ x: Float) -> Float
17+
@differentiable(reverse, primal: pfoo(_:), adjoint: dfoo_checkpointed(_:checkpoints:originalValue:seed:))
18+
func foo_checkpointed(_ x: Float) -> Float {
19+
return x * x
20+
}
21+
22+
struct S<T> {
23+
struct Checkpoints {
24+
let s: S
25+
}
26+
func primal(x: Float) -> (Checkpoints, Float) {
27+
return (Checkpoints(s: self), x)
28+
}
29+
func adjoint_checkpointed(x: Float, _: Checkpoints, _: Float, _: Float) -> S {
30+
return self
31+
}
32+
33+
// CHECK-DAG: @differentiable(reverse, (self), primal: primal, adjoint: adjoint_checkpointed)
34+
// CHECK-DAG: func original(x: Float) -> Float
35+
@differentiable(reverse, withRespectTo: (self), primal: primal, adjoint: adjoint_checkpointed)
36+
func original(x: Float) -> Float {
37+
return x
38+
}
39+
}
40+
41+
func pbaz1<T>(_ x: T, _ y: T) -> ((T, T), T) {
42+
return ((y, y), x)
43+
}
44+
func dbaz1_checkpointed<T>(_ x: T, _ y: T, primal: (T, T), originalValue: T, seed: T) -> (T, T) {
45+
return (y, x)
46+
}
47+
// CHECK-DAG: @differentiable(reverse, primal: pbaz1, adjoint: dbaz1_checkpointed)
48+
// CHECK-DAG: func baz1_checkpointed<T>(_ x: T, _ y: T) -> T
49+
@differentiable(reverse, primal: pbaz1(_:_:), adjoint: dbaz1_checkpointed(_:_:primal:originalValue:seed:))
50+
func baz1_checkpointed<T>(_ x: T, _ y: T) -> T {
51+
return x
52+
}
53+
54+
struct CheckpointsFP<T : FloatingPoint> {
55+
let meow: T
56+
}
57+
func pbaz2<T : FloatingPoint>(_ x: T, _ y: T) -> (CheckpointsFP<T>, T) {
58+
return (CheckpointsFP(meow: 1), x + y)
59+
}
60+
func dbaz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T, primal: CheckpointsFP<T>, originalValue: T, seed: T) -> (T, T) {
61+
return (1, 1)
62+
}
63+
// CHECK-DAG: @differentiable(reverse, primal: pbaz2, adjoint: dbaz2_checkpointed)
64+
// CHECK-DAG: func baz2_checkpointed<T>(_ x: T, _ y: T) -> T where T : FloatingPoint
65+
@differentiable(reverse, primal: pbaz2(_:_:), adjoint: dbaz2_checkpointed(_:_:primal:originalValue:seed:))
66+
func baz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T) -> T {
67+
return x
68+
}

0 commit comments

Comments
 (0)