Skip to content

Commit 46ff5f2

Browse files
committed
Serialization of @differentiable attribute.
1 parent 374e64d commit 46ff5f2

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

lib/Serialization/ModuleFormat.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,19 @@ namespace decls_block {
17341734
GenericSignatureIDField // specialized signature
17351735
>;
17361736

1737-
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
1737+
using DifferentiableDeclAttrLayout = BCRecordLayout<
1738+
Differentiable_DECL_ATTR,
1739+
BCFixed<1>, // Implicit flag.
1740+
BCFixed<1>, // Linear flag.
1741+
IdentifierIDField, // JVP name.
1742+
DeclIDField, // JVP function declaration.
1743+
IdentifierIDField, // VJP name.
1744+
DeclIDField, // VJP function declaration.
1745+
GenericSignatureIDField, // Derivative generic signature.
1746+
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
1747+
>;
1748+
1749+
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
17381750
using CLASS##DeclAttrLayout = BCRecordLayout< \
17391751
CLASS##_DECL_ATTR, \
17401752
BCFixed<1> /* implicit flag */ \

lib/Serialization/Serialization.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,6 +2270,37 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
22702270
break;
22712271
}
22722272

2273+
case DAK_Differentiable: {
2274+
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
2275+
auto *attr = cast<DifferentiableAttr>(DA);
2276+
2277+
IdentifierID jvpName = 0;
2278+
DeclID jvpRef = 0;
2279+
if (auto jvp = attr->getJVP())
2280+
jvpName = S.addDeclBaseNameRef(jvp->Name.getBaseName());
2281+
if (auto jvpFunction = attr->getJVPFunction())
2282+
jvpRef = S.addDeclRef(jvpFunction);
2283+
2284+
IdentifierID vjpName = 0;
2285+
DeclID vjpRef = 0;
2286+
if (auto vjp = attr->getVJP())
2287+
vjpName = S.addDeclBaseNameRef(vjp->Name.getBaseName());
2288+
if (auto vjpFunction = attr->getVJPFunction())
2289+
vjpRef = S.addDeclRef(vjpFunction);
2290+
2291+
auto paramIndices = attr->getParameterIndices();
2292+
assert(paramIndices && "Checked parameter indices must be resolved");
2293+
SmallVector<bool, 4> indices;
2294+
for (unsigned i : swift::indices(paramIndices->parameters))
2295+
indices.push_back(paramIndices->parameters[i]);
2296+
2297+
DifferentiableDeclAttrLayout::emitRecord(
2298+
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
2299+
attr->isLinear(), jvpName, jvpRef, vjpName, vjpRef,
2300+
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
2301+
indices);
2302+
return;
2303+
}
22732304
}
22742305
}
22752306

utils/gyb_syntax_support/NodeSerializationCodes.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,13 @@
237237
'LayoutRequirement': 233,
238238
'LayoutConstraint': 234,
239239
'OpaqueReturnTypeOfAttributeArguments': 235,
240-
'LayoutConstraint': 236,
241-
'DifferentiableAttributeArguments': 237,
242-
'DifferentiationParamsClause': 238,
243-
'DifferentiationParams': 239,
244-
'DifferentiationParamList': 240,
245-
'DifferentiationParam': 241,
246-
'DifferentiableAttributeFuncSpecifier': 242,
247-
'FunctionDeclName': 243,
240+
'DifferentiableAttributeArguments': 236,
241+
'DifferentiationParamsClause': 237,
242+
'DifferentiationParams': 238,
243+
'DifferentiationParamList': 239,
244+
'DifferentiationParam': 240,
245+
'DifferentiableAttributeFuncSpecifier': 241,
246+
'FunctionDeclName': 242,
248247
}
249248

250249

0 commit comments

Comments
 (0)