Skip to content

Commit 979914e

Browse files
committed
Implement @derivative attribut serialization.
Fix test/Serialization/derivative_attr.swift.
1 parent 571e9d9 commit 979914e

File tree

6 files changed

+141
-18
lines changed

6 files changed

+141
-18
lines changed

include/swift/AST/Attr.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ DECL_ATTR(differentiable, Differentiable,
513513
91)
514514
DECL_ATTR(derivative, Derivative,
515515
OnFunc | LongAttribute | AllowMultipleAttributes |
516-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
517-
NotSerialized, 92)
516+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
517+
92)
518518
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
519519
OnAccessor | OnFunc | OnConstructor | OnSubscript |
520520
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove |
@@ -542,8 +542,8 @@ DECL_ATTR(quoted, Quoted,
542542
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
543543
DECL_ATTR(differentiating, Differentiating,
544544
OnFunc | LongAttribute | AllowMultipleAttributes |
545-
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
546-
NotSerialized, 98)
545+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
546+
98)
547547
// SWIFT_ENABLE_TENSORFLOW END
548548

549549
#undef TYPE_ATTR

lib/AST/Attr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -948,8 +948,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
948948
Printer.printAttrName("@derivative");
949949
Printer << "(of: ";
950950
auto *attr = cast<DerivativeAttr>(this);
951-
auto *derivative = cast<AbstractFunctionDecl>(D);
952951
Printer << attr->getOriginalFunctionName().Name;
952+
auto *derivative = cast<AbstractFunctionDecl>(D);
953953
auto diffParamsString = getDifferentiationParametersClauseString(
954954
derivative, attr->getParameterIndices(), attr->getParsedParameters());
955955
if (!diffParamsString.empty())
@@ -963,8 +963,8 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
963963
Printer.printAttrName("@transpose");
964964
Printer << '(';
965965
auto *attr = cast<TransposeAttr>(this);
966-
auto *transpose = cast<AbstractFunctionDecl>(D);
967966
Printer << attr->getOriginalFunctionName().Name;
967+
auto *transpose = cast<AbstractFunctionDecl>(D);
968968
auto transParamsString = getTransposedParametersClauseString(
969969
transpose, attr->getParameterIndices(), attr->getParsedParameters());
970970
if (!transParamsString.empty())

lib/Serialization/Deserialization.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,22 @@ getActualReadWriteImplKind(unsigned rawKind) {
21202120
return None;
21212121
}
21222122

2123+
// SWIFT_ENABLE_TENSORFLOW
2124+
/// Translate from the serialization DifferentiabilityKind enumerators, which
2125+
/// are guaranteed to be stable, to the AST ones.
2126+
static Optional<swift::AutoDiffDerivativeFunctionKind>
2127+
getActualAutoDiffDerivativeFunctionKind(uint8_t raw) {
2128+
switch (serialization::AutoDiffDerivativeFunctionKind(raw)) {
2129+
#define CASE(ID) \
2130+
case serialization::AutoDiffDerivativeFunctionKind::ID: \
2131+
return {swift::AutoDiffDerivativeFunctionKind::ID};
2132+
CASE(JVP)
2133+
CASE(VJP)
2134+
#undef CASE
2135+
}
2136+
return None;
2137+
}
2138+
21232139
void ModuleFile::configureStorage(AbstractStorageDecl *decl,
21242140
uint8_t rawOpaqueReadOwnership,
21252141
uint8_t rawReadImplKind,
@@ -4151,6 +4167,38 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41514167
break;
41524168
}
41534169

4170+
// SWIFT_ENABLE_TENSORFLOW
4171+
case decls_block::Derivative_DECL_ATTR: {
4172+
bool isImplicit;
4173+
uint64_t origNameId;
4174+
DeclID origDeclId;
4175+
uint64_t rawDerivativeKind;
4176+
ArrayRef<uint64_t> parameters;
4177+
4178+
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
4179+
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
4180+
parameters);
4181+
4182+
DeclNameWithLoc origName{MF.getIdentifier(origNameId), DeclNameLoc()};
4183+
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
4184+
auto derivativeKind =
4185+
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
4186+
if (!derivativeKind)
4187+
MF.fatal();
4188+
llvm::SmallBitVector parametersBitVector(parameters.size());
4189+
for (unsigned i : indices(parameters))
4190+
parametersBitVector[i] = parameters[i];
4191+
auto *indices = IndexSubset::get(ctx, parametersBitVector);
4192+
4193+
auto *derivAttr = DerivativeAttr::create(
4194+
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
4195+
derivAttr->setOriginalFunction(origDecl);
4196+
derivAttr->setDerivativeKind(*derivativeKind);
4197+
Attr = derivAttr;
4198+
break;
4199+
}
4200+
// SWIFT_ENABLE_TENSORFLOW END
4201+
41544202
case decls_block::DynamicReplacement_DECL_ATTR: {
41554203
bool isImplicit;
41564204
uint64_t numArgs;

lib/Serialization/ModuleFormat.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5252
/// describe what change you made. The content of this comment isn't important;
5353
/// it just ensures a conflict if two people change the module format.
5454
/// Don't worry about adhering to the 80-column limit for this line.
55-
const uint16_t SWIFTMODULE_VERSION_MINOR = 528; // derivative function config table
55+
const uint16_t SWIFTMODULE_VERSION_MINOR = 529; // `@derivative` serialization
5656

5757
/// A standard hash seed used for all string hashes in a serialized module.
5858
///
@@ -232,6 +232,16 @@ enum class DifferentiabilityKind : uint8_t {
232232
};
233233
using DifferentiabilityKindField = BCFixed<2>;
234234

235+
// SWIFT_ENABLE_TENSORFLOW
236+
// These IDs must \em not be renumbered or reordered without incrementing the
237+
// module version.
238+
enum class AutoDiffDerivativeFunctionKind : uint8_t {
239+
JVP = 0,
240+
VJP = 1
241+
};
242+
using AutoDiffDerivativeFunctionKindField = BCFixed<1>;
243+
// SWIFT_ENABLE_TENSORFLOW END
244+
235245
enum class ForeignErrorConventionKind : uint8_t {
236246
ZeroResult,
237247
NonZeroResult,
@@ -1772,19 +1782,20 @@ namespace decls_block {
17721782
BCFixed<1>, // Implicit flag.
17731783
IdentifierIDField, // Original name.
17741784
DeclIDField, // Original function declaration.
1785+
AutoDiffDerivativeFunctionKindField, // Derivative function kind.
17751786
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
17761787
>;
17771788

17781789
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
17791790
using DifferentiatingDeclAttrLayout = DerivativeDeclAttrLayout;
1780-
1791+
17811792
// SWIFT_ENABLE_TENSORFLOW
17821793
using TransposeDeclAttrLayout = BCRecordLayout<
17831794
Transpose_DECL_ATTR,
17841795
BCFixed<1>, // Implicit flag.
17851796
IdentifierIDField, // Original name.
17861797
DeclIDField, // Original function declaration.
1787-
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
1798+
BCArray<BCFixed<1>> // Transposed parameter indices' bitvector.
17881799
>;
17891800

17901801
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \

lib/Serialization/Serialization.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,6 +2029,21 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) {
20292029
llvm_unreachable("bad variable decl introducer kind");
20302030
}
20312031

2032+
// SWIFT_ENABLE_TENSORFLOW
2033+
/// Translate from the AST differentiability kind enum to the Serialization enum
2034+
/// values, which are guaranteed to be stable.
2035+
static uint8_t getRawStableAutoDiffDerivativeFunctionKind(
2036+
swift::AutoDiffDerivativeFunctionKind kind) {
2037+
switch (kind) {
2038+
case swift::AutoDiffDerivativeFunctionKind::JVP:
2039+
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::JVP);
2040+
case swift::AutoDiffDerivativeFunctionKind::VJP:
2041+
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::VJP);
2042+
}
2043+
llvm_unreachable("bad derivative function kind");
2044+
}
2045+
// SWIFT_ENABLE_TENSORFLOW END
2046+
20322047
/// Returns true if the declaration of \p decl depends on \p problemContext
20332048
/// based on lexical nesting.
20342049
///
@@ -2129,9 +2144,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
21292144
case DAK_PrivateImport:
21302145
llvm_unreachable("cannot serialize attribute");
21312146
// SWIFT_ENABLE_TENSORFLOW
2132-
case DAK_Derivative:
21332147
case DAK_Transpose:
2134-
case DAK_Differentiating:
21352148
llvm_unreachable("cannot serialize attribute");
21362149
// SWIFT_ENABLE_TENSORFLOW END
21372150

@@ -2345,6 +2358,30 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
23452358
return;
23462359
}
23472360

2361+
// SWIFT_ENABLE_TENSORFLOW
2362+
case DAK_Derivative:
2363+
case DAK_Differentiating: {
2364+
auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code];
2365+
auto *attr = cast<DerivativeAttr>(DA);
2366+
assert(attr->getOriginalFunction() &&
2367+
"`@derivative` attribute should have original declaration set "
2368+
"during construction or parsing");
2369+
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
2370+
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
2371+
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
2372+
auto derivativeKind =
2373+
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
2374+
auto paramIndices = attr->getParameterIndices();
2375+
assert(paramIndices && "Parameter indices must be resolved");
2376+
SmallVector<bool, 4> indices;
2377+
for (unsigned i : range(paramIndices->getCapacity()))
2378+
indices.push_back(paramIndices->contains(i));
2379+
DerivativeDeclAttrLayout::emitRecord(
2380+
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
2381+
origDeclID, derivativeKind, indices);
2382+
return;
2383+
}
2384+
23482385
case DAK_Quoted: {
23492386
auto abbrCode = S.DeclTypeAbbrCodes[QuotedDeclAttrLayout::Code];
23502387
auto attr = cast<QuotedAttr>(DA);
@@ -2354,6 +2391,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
23542391
S.addDeclRef(attr->getQuoteDecl()));
23552392
return;
23562393
}
2394+
// SWIFT_ENABLE_TENSORFLOW END
23572395
}
23582396
}
23592397

test/Serialization/derivative_attr.swift

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,24 @@
77

88
// BCANALYZER-NOT: UnknownCode
99

10-
// CHECK: @differentiable(wrt: x, jvp: jvpAddWrtX)
11-
// CHECK-NEXT: @differentiable(wrt: (x, y), vjp: vjpAdd)
1210
func add(x: Float, y: Float) -> Float {
1311
return x + y
1412
}
13+
// CHECK: @derivative(of: add, wrt: x)
1514
@derivative(of: add, wrt: x)
1615
func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
1716
return (x + y, { $0 })
1817
}
18+
// CHECK: @derivative(of: add, wrt: (x, y))
1919
@derivative(of: add)
2020
func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
2121
return (x + y, { ($0, $0) })
2222
}
2323

24-
// CHECK: @differentiable(wrt: x, vjp: vjpGeneric where T : Differentiable)
2524
func generic<T : Numeric>(x: T) -> T {
2625
return x
2726
}
27+
// CHECK: @derivative(of: generic, wrt: x)
2828
@derivative(of: generic)
2929
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
3030
where T : Numeric, T : Differentiable
@@ -33,21 +33,47 @@ func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentV
3333
}
3434

3535
protocol InstanceMethod : Differentiable {
36-
// CHECK: @differentiable(wrt: (self, x), vjp: vjpFoo)
3736
func foo(_ x: Self) -> Self
38-
// CHECK: @differentiable(wrt: (self, x), jvp: jvpBarWrt where T == T.TangentVector)
3937
func bar<T : Differentiable>(_ x: T) -> Self
4038
}
4139
extension InstanceMethod {
40+
// CHECK: @derivative(of: foo, wrt: (self, x))
4241
@derivative(of: foo)
43-
func vjpFoo(x: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
42+
func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
4443
return (x, { ($0, $0) })
4544
}
4645

46+
// CHECK: @derivative(of: bar, wrt: (self, x))
4747
@derivative(of: bar, wrt: (self, x))
48-
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (Self.TangentVector, T) -> Self.TangentVector)
48+
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector)
4949
where T == T.TangentVector
5050
{
5151
return (self, { dself, dx in dself })
5252
}
53+
54+
// CHECK: @derivative(of: bar, wrt: (self, x))
55+
@derivative(of: bar, wrt: (self, x))
56+
func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T))
57+
where T == T.TangentVector
58+
{
59+
return (self, { v in (v, .zero) })
60+
}
61+
}
62+
63+
// Test deprecated `@differentiating` attribute.
64+
// For simplicity, `@differentiating` is serialized/deserialized as
65+
// `@derivative` attribute.
66+
67+
func subtract(x: Float, y: Float) -> Float {
68+
return x - y
69+
}
70+
// CHECK: @derivative(of: subtract, wrt: x)
71+
@differentiating(subtract, wrt: x)
72+
func jvpSubtractWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
73+
return (x - y, { $0 })
74+
}
75+
// CHECK: @derivative(of: subtract, wrt: (x, y))
76+
@differentiating(subtract)
77+
func vjpSubtract(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
78+
return (x - y, { ($0, -$0) })
5379
}

0 commit comments

Comments
 (0)