Skip to content

Commit 7b064a6

Browse files
authored
[AutoDiff upstream] Add SIL differentiability witness serialization. (#29642)
SIL differentiability witnesses are a new top-level SIL construct mapping an "original" SIL function and derivative configuration to derivative SIL functions. This patch adds `SILDifferentiabilityWitness` serialization/deserialization. Resolves TF-1136.
1 parent bd26447 commit 7b064a6

File tree

11 files changed

+331
-13
lines changed

11 files changed

+331
-13
lines changed

include/swift/SIL/SILModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,10 @@ class SILModule {
637637
llvm::ArrayRef<SILDifferentiabilityWitness *>
638638
lookUpDifferentiabilityWitnessesForFunction(StringRef name);
639639

640+
/// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if
641+
/// deserialization succeeded, false otherwise.
642+
bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
643+
640644
// Given a protocol, attempt to create a default witness table declaration
641645
// for it.
642646
SILDefaultWitnessTable *

include/swift/Serialization/SerializedSILLoader.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class SILModule;
3232
class SILVTable;
3333
class SILWitnessTable;
3434
class SILDefaultWitnessTable;
35+
class SILDifferentiabilityWitness;
3536

3637
/// Maintains a list of SILDeserializer, one for each serialized modules
3738
/// in ASTContext. It provides lookupSILFunction that will perform lookup
@@ -64,6 +65,8 @@ class SerializedSILLoader {
6465
SILVTable *lookupVTable(const ClassDecl *C);
6566
SILWitnessTable *lookupWitnessTable(SILWitnessTable *C);
6667
SILDefaultWitnessTable *lookupDefaultWitnessTable(SILDefaultWitnessTable *C);
68+
SILDifferentiabilityWitness *
69+
lookupDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);
6770

6871
/// Invalidate the cached entries for deserialized SILFunctions.
6972
void invalidateCaches();

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212

1313
#define DEBUG_TYPE "sil-differentiability-witness"
1414

15-
// SWIFT_ENABLE_TENSORFLOW
1615
#include "swift/AST/ASTMangler.h"
17-
// SWIFT_ENABLE_TENSORFLOW_END
1816
#include "swift/SIL/SILDifferentiabilityWitness.h"
1917
#include "swift/SIL/SILModule.h"
2018

@@ -51,7 +49,6 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
5149
derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
5250
attribute);
5351
// Register the differentiability witness in the module.
54-
// Register the differentiability witness in the module.
5552
Mangle::ASTMangler mangler;
5653
auto mangledKey =
5754
mangler.mangleSILDifferentiabilityWitnessKey(diffWitness->getKey());

lib/SIL/SILModule.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,9 +580,8 @@ lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member) {
580580
SILDifferentiabilityWitness *
581581
SILModule::lookUpDifferentiabilityWitness(StringRef name) {
582582
auto it = DifferentiabilityWitnessMap.find(name);
583-
if (it != DifferentiabilityWitnessMap.end()) {
583+
if (it != DifferentiabilityWitnessMap.end())
584584
return it->second;
585-
}
586585
return nullptr;
587586
}
588587

@@ -599,6 +598,14 @@ SILModule::lookUpDifferentiabilityWitnessesForFunction(StringRef name) {
599598
return DifferentiabilityWitnessesByFunction[name];
600599
}
601600

601+
bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *dw) {
602+
auto *newDW = getSILLoader()->lookupDifferentiabilityWitness(dw->getKey());
603+
if (!newDW)
604+
return false;
605+
assert(dw == newDW);
606+
return true;
607+
}
608+
602609
void SILModule::registerDeserializationNotificationHandler(
603610
std::unique_ptr<DeserializationNotificationHandler> &&handler) {
604611
deserializationNotificationHandlers.add(std::move(handler));

lib/Serialization/DeserializeSIL.cpp

Lines changed: 164 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ SILDeserializer::SILDeserializer(
165165
kind == sil_index_block::SIL_GLOBALVAR_NAMES ||
166166
kind == sil_index_block::SIL_WITNESS_TABLE_NAMES ||
167167
kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES ||
168-
kind == sil_index_block::SIL_PROPERTY_OFFSETS)) &&
168+
kind == sil_index_block::SIL_PROPERTY_OFFSETS ||
169+
kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)) &&
169170
"Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \
170-
SIL_WITNESS_TABLE_NAMES, or SIL_DEFAULT_WITNESS_TABLE_NAMES.");
171+
SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \
172+
SIL_PROPERTY_OFFSETS, or SIL_DIFFERENTIABILITY_WITNESS_NAMES.");
171173
(void)prevKind;
172174

173175
if (kind == sil_index_block::SIL_FUNC_NAMES)
@@ -180,6 +182,8 @@ SILDeserializer::SILDeserializer(
180182
WitnessTableList = readFuncTable(scratch, blobData);
181183
else if (kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES)
182184
DefaultWitnessTableList = readFuncTable(scratch, blobData);
185+
else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)
186+
DifferentiabilityWitnessList = readFuncTable(scratch, blobData);
183187
else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) {
184188
// No matching 'names' block for property descriptors needed yet.
185189
MF->allocateBuffer(Properties, scratch);
@@ -217,6 +221,12 @@ SILDeserializer::SILDeserializer(
217221
offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) &&
218222
"Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record.");
219223
MF->allocateBuffer(DefaultWitnessTables, scratch);
224+
} else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) {
225+
assert((next.Kind == llvm::BitstreamEntry::Record &&
226+
offKind ==
227+
sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) &&
228+
"Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record.");
229+
MF->allocateBuffer(DifferentiabilityWitnesses, scratch);
220230
}
221231
}
222232
}
@@ -339,6 +349,24 @@ SILType SILDeserializer::getSILType(Type Ty, SILValueCategory Category,
339349
.getCategoryType(Category);
340350
}
341351

352+
/// Helper function to find a SILDifferentiabilityWitness, given its mangled
353+
/// key.
354+
SILDifferentiabilityWitness *
355+
SILDeserializer::getSILDifferentiabilityWitnessForReference(
356+
StringRef mangledKey) {
357+
// Check to see if we have a witness under this key already.
358+
auto *witness = SILMod.lookUpDifferentiabilityWitness(mangledKey);
359+
if (witness)
360+
return witness;
361+
// Otherwise, look for a witness under this key in the module.
362+
if (!DifferentiabilityWitnessList)
363+
return nullptr;
364+
auto iter = DifferentiabilityWitnessList->find(mangledKey);
365+
if (iter == DifferentiabilityWitnessList->end())
366+
return nullptr;
367+
return readDifferentiabilityWitness(*iter);
368+
}
369+
342370
/// Helper function to find a SILFunction, given its name and type.
343371
SILFunction *SILDeserializer::getFuncForReference(StringRef name,
344372
SILType type) {
@@ -760,7 +788,7 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
760788
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
761789
// of this SILFunction.
762790
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
763-
kind != SIL_WITNESS_TABLE) {
791+
kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) {
764792
if (kind == SIL_BASIC_BLOCK)
765793
// Handle a SILBasicBlock record.
766794
CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch);
@@ -2988,6 +3016,7 @@ void SILDeserializer::readWitnessTableEntries(
29883016
// Another record means the end of this WitnessTable.
29893017
while (kind != SIL_WITNESS_TABLE &&
29903018
kind != SIL_DEFAULT_WITNESS_TABLE &&
3019+
kind != SIL_DIFFERENTIABILITY_WITNESS &&
29913020
kind != SIL_FUNCTION) {
29923021
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
29933022
witnessEntries.push_back(SILDefaultWitnessTable::Entry());
@@ -3343,6 +3372,138 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) {
33433372
return Wt;
33443373
}
33453374

3375+
SILDifferentiabilityWitness *
3376+
SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
3377+
if (DId == 0)
3378+
return nullptr;
3379+
assert(DId <= DifferentiabilityWitnesses.size() &&
3380+
"Invalid SILDifferentiabilityWitness ID");
3381+
3382+
auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId - 1];
3383+
if (diffWitnessOrOffset.isFullyDeserialized())
3384+
return diffWitnessOrOffset.get();
3385+
3386+
BCOffsetRAII restoreOffset(SILCursor);
3387+
if (auto err = SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()))
3388+
MF->fatal(std::move(err));
3389+
llvm::Expected<llvm::BitstreamEntry> maybeEntry =
3390+
SILCursor.advance(AF_DontPopBlockAtEnd);
3391+
if (!maybeEntry)
3392+
MF->fatal(maybeEntry.takeError());
3393+
auto entry = maybeEntry.get();
3394+
if (entry.Kind == llvm::BitstreamEntry::Error) {
3395+
LLVM_DEBUG(llvm::dbgs() << "Cursor advance error in "
3396+
"readDefaultWitnessTable.\n");
3397+
return nullptr;
3398+
}
3399+
3400+
SmallVector<uint64_t, 64> scratch;
3401+
StringRef blobData;
3402+
llvm::Expected<unsigned> maybeKind =
3403+
SILCursor.readRecord(entry.ID, scratch, &blobData);
3404+
if (!maybeKind)
3405+
MF->fatal(maybeKind.takeError());
3406+
unsigned kind = maybeKind.get();
3407+
assert(kind == SIL_DIFFERENTIABILITY_WITNESS &&
3408+
"Expected sil_differentiability_witness");
3409+
(void)kind;
3410+
3411+
DeclID originalNameId, jvpNameId, vjpNameId;
3412+
unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
3413+
numResultIndices;
3414+
GenericSignatureID derivativeGenSigID;
3415+
ArrayRef<uint64_t> rawParameterAndResultIndices;
3416+
3417+
DifferentiabilityWitnessLayout::readRecord(
3418+
scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
3419+
derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
3420+
numResultIndices, rawParameterAndResultIndices);
3421+
3422+
if (isDeclaration) {
3423+
assert(!isSerialized && "declaration must not be serialized");
3424+
}
3425+
3426+
auto linkageOpt = fromStableSILLinkage(rawLinkage);
3427+
assert(linkageOpt &&
3428+
"Expected value linkage for sil_differentiability_witness");
3429+
auto originalName = MF->getIdentifierText(originalNameId);
3430+
auto jvpName = MF->getIdentifierText(jvpNameId);
3431+
auto vjpName = MF->getIdentifierText(vjpNameId);
3432+
auto *original = getFuncForReference(originalName);
3433+
assert(original && "Original function must be found");
3434+
auto *jvp = getFuncForReference(jvpName);
3435+
if (!jvpName.empty()) {
3436+
assert(!isDeclaration && "JVP must not be defined in declaration");
3437+
assert(jvp && "JVP function must be found if JVP name is not empty");
3438+
}
3439+
auto *vjp = getFuncForReference(vjpName);
3440+
if (!vjpName.empty()) {
3441+
assert(!isDeclaration && "VJP must not be defined in declaration");
3442+
assert(vjp && "VJP function must be found if VJP name is not empty");
3443+
}
3444+
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);
3445+
3446+
SmallVector<unsigned, 8> parameterAndResultIndices(
3447+
rawParameterAndResultIndices.begin(), rawParameterAndResultIndices.end());
3448+
assert(parameterAndResultIndices.size() ==
3449+
numParameterIndices + numResultIndices &&
3450+
"Parameter/result indices count mismatch");
3451+
auto *parameterIndices = IndexSubset::get(
3452+
MF->getContext(), original->getLoweredFunctionType()->getNumParameters(),
3453+
ArrayRef<unsigned>(parameterAndResultIndices)
3454+
.take_front(numParameterIndices));
3455+
auto *resultIndices = IndexSubset::get(
3456+
MF->getContext(), original->getLoweredFunctionType()->getNumResults(),
3457+
ArrayRef<unsigned>(parameterAndResultIndices)
3458+
.take_back(numResultIndices));
3459+
3460+
AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig);
3461+
auto *diffWitness =
3462+
SILMod.lookUpDifferentiabilityWitness({originalName, config});
3463+
3464+
// Witnesses that we deserialize are always available externally; we never
3465+
// want to emit them ourselves.
3466+
auto linkage = swift::addExternalToLinkage(*linkageOpt);
3467+
3468+
// If there is no existing differentiability witness, create one.
3469+
if (!diffWitness)
3470+
diffWitness = SILDifferentiabilityWitness::createDeclaration(
3471+
SILMod, linkage, original, parameterIndices, resultIndices,
3472+
derivativeGenSig);
3473+
3474+
// If the current differentiability witness is merely a declaration, and the
3475+
// deserialized witness is a definition, upgrade the current differentiability
3476+
// witness to a definition. This can happen in the following situations:
3477+
// 1. The witness was just created above.
3478+
// 2. The witness started out as a declaration (e.g. the differentiation
3479+
// pass emitted a witness for an external function) and now we're loading
3480+
// the definition (e.g. an optimization pass asked for the definition and
3481+
// we found the definition serialized in this module).
3482+
if (diffWitness->isDeclaration() && !isDeclaration)
3483+
diffWitness->convertToDefinition(jvp, vjp, isSerialized);
3484+
3485+
diffWitnessOrOffset.set(diffWitness,
3486+
/*isFullyDeserialized*/ diffWitness->isDefinition());
3487+
return diffWitness;
3488+
}
3489+
3490+
SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness(
3491+
StringRef mangledDiffWitnessKey) {
3492+
if (!DifferentiabilityWitnessList)
3493+
return nullptr;
3494+
auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey);
3495+
if (iter == DifferentiabilityWitnessList->end())
3496+
return nullptr;
3497+
return readDifferentiabilityWitness(*iter);
3498+
}
3499+
3500+
void SILDeserializer::getAllDifferentiabilityWitnesses() {
3501+
if (!DifferentiabilityWitnessList)
3502+
return;
3503+
for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I)
3504+
readDifferentiabilityWitness(I + 1);
3505+
}
3506+
33463507
SILDeserializer::~SILDeserializer() {
33473508
// Drop our references to anything we've deserialized.
33483509
for (auto &fnEntry : Funcs) {

lib/Serialization/DeserializeSIL.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ namespace swift {
5858
MutableArrayRef<ModuleFile::PartiallySerialized<SILProperty *>>
5959
Properties;
6060

61+
std::unique_ptr<SerializedFuncTable> DifferentiabilityWitnessList;
62+
MutableArrayRef<
63+
ModuleFile::PartiallySerialized<SILDifferentiabilityWitness *>>
64+
DifferentiabilityWitnesses;
65+
6166
/// A declaration will only
6267
llvm::DenseMap<NormalProtocolConformance *, SILWitnessTable *>
6368
ConformanceToWitnessTableMap;
@@ -113,6 +118,9 @@ namespace swift {
113118
SILType getSILType(Type ty, SILValueCategory category,
114119
SILFunction *inContext);
115120

121+
SILDifferentiabilityWitness *
122+
getSILDifferentiabilityWitnessForReference(StringRef mangledKey);
123+
116124
SILFunction *getFuncForReference(StringRef Name, SILType Ty);
117125
SILFunction *getFuncForReference(StringRef Name);
118126
SILVTable *readVTable(serialization::DeclID);
@@ -129,6 +137,8 @@ namespace swift {
129137
SILDefaultWitnessTable *
130138
readDefaultWitnessTable(serialization::DeclID,
131139
SILDefaultWitnessTable *existingWt);
140+
SILDifferentiabilityWitness *
141+
readDifferentiabilityWitness(serialization::DeclID);
132142

133143
Optional<KeyPathPatternComponent>
134144
readKeyPathComponent(ArrayRef<uint64_t> ListOfValues, unsigned &nextValue);
@@ -148,6 +158,8 @@ namespace swift {
148158
SILWitnessTable *lookupWitnessTable(SILWitnessTable *wt);
149159
SILDefaultWitnessTable *
150160
lookupDefaultWitnessTable(SILDefaultWitnessTable *wt);
161+
SILDifferentiabilityWitness *
162+
lookupDifferentiabilityWitness(StringRef mangledDiffWitnessKey);
151163

152164
/// Invalidate all cached SILFunctions.
153165
void invalidateFunctionCache();
@@ -172,6 +184,7 @@ namespace swift {
172184
getAllWitnessTables();
173185
getAllDefaultWitnessTables();
174186
getAllProperties();
187+
getAllDifferentiabilityWitnesses();
175188
}
176189

177190
/// Deserialize all SILFunctions inside the module and add them to SILMod.
@@ -195,6 +208,10 @@ namespace swift {
195208
/// to SILMod.
196209
void getAllProperties();
197210

211+
/// Deserialize all DifferentiabilityWitnesses inside the module and add
212+
/// them to SILMod.
213+
void getAllDifferentiabilityWitnesses();
214+
198215
SILDeserializer(ModuleFile *MF, SILModule &M,
199216
DeserializationNotificationHandlerSet *callback);
200217

lib/Serialization/ModuleFormat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 536; // Clang function types
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 537; // SIL differentiability witnesses
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///

lib/Serialization/SILFormat.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ namespace sil_index_block {
9797
SIL_DEFAULT_WITNESS_TABLE_NAMES,
9898
SIL_DEFAULT_WITNESS_TABLE_OFFSETS,
9999
SIL_PROPERTY_OFFSETS,
100+
SIL_DIFFERENTIABILITY_WITNESS_NAMES,
101+
SIL_DIFFERENTIABILITY_WITNESS_OFFSETS,
100102
};
101103

102104
using ListLayout = BCGenericRecordLayout<
@@ -141,6 +143,7 @@ namespace sil_block {
141143
SIL_WITNESS_CONDITIONAL_CONFORMANCE,
142144
SIL_DEFAULT_WITNESS_TABLE,
143145
SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY,
146+
SIL_DIFFERENTIABILITY_WITNESS,
144147
SIL_INST_WITNESS_METHOD,
145148
SIL_SPECIALIZE_ATTR,
146149
SIL_PROPERTY,
@@ -250,6 +253,20 @@ namespace sil_block {
250253
DeclIDField
251254
>;
252255

256+
using DifferentiabilityWitnessLayout = BCRecordLayout<
257+
SIL_DIFFERENTIABILITY_WITNESS,
258+
DeclIDField, // Original function name
259+
SILLinkageField, // Linkage
260+
BCFixed<1>, // Is declaration?
261+
BCFixed<1>, // Is serialized?
262+
GenericSignatureIDField, // Derivative function generic signature
263+
DeclIDField, // JVP function name
264+
DeclIDField, // VJP function name
265+
BCVBR<8>, // Number of parameter indices
266+
BCVBR<8>, // Number of result indices
267+
BCArray<ValueIDField> // Parameter and result indices
268+
>;
269+
253270
using SILFunctionLayout =
254271
BCRecordLayout<SIL_FUNCTION, SILLinkageField,
255272
BCFixed<1>, // transparent

0 commit comments

Comments
 (0)