Skip to content

[AutoDiff upstream] Add SIL differentiability witness serialization. #29642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,10 @@ class SILModule {
llvm::ArrayRef<SILDifferentiabilityWitness *>
lookUpDifferentiabilityWitnessesForFunction(StringRef name);

/// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if
/// deserialization succeeded, false otherwise.
bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *dw);

// Given a protocol, attempt to create a default witness table declaration
// for it.
SILDefaultWitnessTable *
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Serialization/SerializedSILLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class SILModule;
class SILVTable;
class SILWitnessTable;
class SILDefaultWitnessTable;
class SILDifferentiabilityWitness;

/// Maintains a list of SILDeserializer, one for each serialized modules
/// in ASTContext. It provides lookupSILFunction that will perform lookup
Expand Down Expand Up @@ -64,6 +65,8 @@ class SerializedSILLoader {
SILVTable *lookupVTable(const ClassDecl *C);
SILWitnessTable *lookupWitnessTable(SILWitnessTable *C);
SILDefaultWitnessTable *lookupDefaultWitnessTable(SILDefaultWitnessTable *C);
SILDifferentiabilityWitness *
lookupDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);

/// Invalidate the cached entries for deserialized SILFunctions.
void invalidateCaches();
Expand Down
3 changes: 0 additions & 3 deletions lib/SIL/SILDifferentiabilityWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

#define DEBUG_TYPE "sil-differentiability-witness"

// SWIFT_ENABLE_TENSORFLOW
#include "swift/AST/ASTMangler.h"
// SWIFT_ENABLE_TENSORFLOW_END
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILModule.h"

Expand Down Expand Up @@ -51,7 +49,6 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
attribute);
// Register the differentiability witness in the module.
// Register the differentiability witness in the module.
Mangle::ASTMangler mangler;
auto mangledKey =
mangler.mangleSILDifferentiabilityWitnessKey(diffWitness->getKey());
Expand Down
11 changes: 9 additions & 2 deletions lib/SIL/SILModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,8 @@ lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member) {
SILDifferentiabilityWitness *
SILModule::lookUpDifferentiabilityWitness(StringRef name) {
auto it = DifferentiabilityWitnessMap.find(name);
if (it != DifferentiabilityWitnessMap.end()) {
if (it != DifferentiabilityWitnessMap.end())
return it->second;
}
return nullptr;
}

Expand All @@ -599,6 +598,14 @@ SILModule::lookUpDifferentiabilityWitnessesForFunction(StringRef name) {
return DifferentiabilityWitnessesByFunction[name];
}

bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *dw) {
auto *newDW = getSILLoader()->lookupDifferentiabilityWitness(dw->getKey());
if (!newDW)
return false;
assert(dw == newDW);
return true;
}

void SILModule::registerDeserializationNotificationHandler(
std::unique_ptr<DeserializationNotificationHandler> &&handler) {
deserializationNotificationHandlers.add(std::move(handler));
Expand Down
167 changes: 164 additions & 3 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ SILDeserializer::SILDeserializer(
kind == sil_index_block::SIL_GLOBALVAR_NAMES ||
kind == sil_index_block::SIL_WITNESS_TABLE_NAMES ||
kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES ||
kind == sil_index_block::SIL_PROPERTY_OFFSETS)) &&
kind == sil_index_block::SIL_PROPERTY_OFFSETS ||
kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)) &&
"Expect SIL_FUNC_NAMES, SIL_VTABLE_NAMES, SIL_GLOBALVAR_NAMES, \
SIL_WITNESS_TABLE_NAMES, or SIL_DEFAULT_WITNESS_TABLE_NAMES.");
SIL_WITNESS_TABLE_NAMES, SIL_DEFAULT_WITNESS_TABLE_NAMES, \
SIL_PROPERTY_OFFSETS, or SIL_DIFFERENTIABILITY_WITNESS_NAMES.");
(void)prevKind;

if (kind == sil_index_block::SIL_FUNC_NAMES)
Expand All @@ -180,6 +182,8 @@ SILDeserializer::SILDeserializer(
WitnessTableList = readFuncTable(scratch, blobData);
else if (kind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_NAMES)
DefaultWitnessTableList = readFuncTable(scratch, blobData);
else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES)
DifferentiabilityWitnessList = readFuncTable(scratch, blobData);
else if (kind == sil_index_block::SIL_PROPERTY_OFFSETS) {
// No matching 'names' block for property descriptors needed yet.
MF->allocateBuffer(Properties, scratch);
Expand Down Expand Up @@ -217,6 +221,12 @@ SILDeserializer::SILDeserializer(
offKind == sil_index_block::SIL_DEFAULT_WITNESS_TABLE_OFFSETS) &&
"Expect a SIL_DEFAULT_WITNESS_TABLE_OFFSETS record.");
MF->allocateBuffer(DefaultWitnessTables, scratch);
} else if (kind == sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES) {
assert((next.Kind == llvm::BitstreamEntry::Record &&
offKind ==
sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_OFFSETS) &&
"Expect a SIL_DIFFERENTIABILITY_WITNESS_OFFSETS record.");
MF->allocateBuffer(DifferentiabilityWitnesses, scratch);
}
}
}
Expand Down Expand Up @@ -339,6 +349,24 @@ SILType SILDeserializer::getSILType(Type Ty, SILValueCategory Category,
.getCategoryType(Category);
}

/// Helper function to find a SILDifferentiabilityWitness, given its mangled
/// key.
SILDifferentiabilityWitness *
SILDeserializer::getSILDifferentiabilityWitnessForReference(
StringRef mangledKey) {
// Check to see if we have a witness under this key already.
auto *witness = SILMod.lookUpDifferentiabilityWitness(mangledKey);
if (witness)
return witness;
// Otherwise, look for a witness under this key in the module.
if (!DifferentiabilityWitnessList)
return nullptr;
auto iter = DifferentiabilityWitnessList->find(mangledKey);
if (iter == DifferentiabilityWitnessList->end())
return nullptr;
return readDifferentiabilityWitness(*iter);
}

/// Helper function to find a SILFunction, given its name and type.
SILFunction *SILDeserializer::getFuncForReference(StringRef name,
SILType type) {
Expand Down Expand Up @@ -760,7 +788,7 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
// of this SILFunction.
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
kind != SIL_WITNESS_TABLE) {
kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) {
if (kind == SIL_BASIC_BLOCK)
// Handle a SILBasicBlock record.
CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch);
Expand Down Expand Up @@ -2988,6 +3016,7 @@ void SILDeserializer::readWitnessTableEntries(
// Another record means the end of this WitnessTable.
while (kind != SIL_WITNESS_TABLE &&
kind != SIL_DEFAULT_WITNESS_TABLE &&
kind != SIL_DIFFERENTIABILITY_WITNESS &&
kind != SIL_FUNCTION) {
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
witnessEntries.push_back(SILDefaultWitnessTable::Entry());
Expand Down Expand Up @@ -3343,6 +3372,138 @@ SILDeserializer::lookupDefaultWitnessTable(SILDefaultWitnessTable *existingWt) {
return Wt;
}

SILDifferentiabilityWitness *
SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
if (DId == 0)
return nullptr;
assert(DId <= DifferentiabilityWitnesses.size() &&
"Invalid SILDifferentiabilityWitness ID");

auto &diffWitnessOrOffset = DifferentiabilityWitnesses[DId - 1];
if (diffWitnessOrOffset.isFullyDeserialized())
return diffWitnessOrOffset.get();

BCOffsetRAII restoreOffset(SILCursor);
if (auto err = SILCursor.JumpToBit(diffWitnessOrOffset.getOffset()))
MF->fatal(std::move(err));
llvm::Expected<llvm::BitstreamEntry> maybeEntry =
SILCursor.advance(AF_DontPopBlockAtEnd);
if (!maybeEntry)
MF->fatal(maybeEntry.takeError());
auto entry = maybeEntry.get();
if (entry.Kind == llvm::BitstreamEntry::Error) {
LLVM_DEBUG(llvm::dbgs() << "Cursor advance error in "
"readDefaultWitnessTable.\n");
return nullptr;
}

SmallVector<uint64_t, 64> scratch;
StringRef blobData;
llvm::Expected<unsigned> maybeKind =
SILCursor.readRecord(entry.ID, scratch, &blobData);
if (!maybeKind)
MF->fatal(maybeKind.takeError());
unsigned kind = maybeKind.get();
assert(kind == SIL_DIFFERENTIABILITY_WITNESS &&
"Expected sil_differentiability_witness");
(void)kind;

DeclID originalNameId, jvpNameId, vjpNameId;
unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
numResultIndices;
GenericSignatureID derivativeGenSigID;
ArrayRef<uint64_t> rawParameterAndResultIndices;

DifferentiabilityWitnessLayout::readRecord(
scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
numResultIndices, rawParameterAndResultIndices);

if (isDeclaration) {
assert(!isSerialized && "declaration must not be serialized");
}

auto linkageOpt = fromStableSILLinkage(rawLinkage);
assert(linkageOpt &&
"Expected value linkage for sil_differentiability_witness");
auto originalName = MF->getIdentifierText(originalNameId);
auto jvpName = MF->getIdentifierText(jvpNameId);
auto vjpName = MF->getIdentifierText(vjpNameId);
auto *original = getFuncForReference(originalName);
assert(original && "Original function must be found");
auto *jvp = getFuncForReference(jvpName);
if (!jvpName.empty()) {
assert(!isDeclaration && "JVP must not be defined in declaration");
assert(jvp && "JVP function must be found if JVP name is not empty");
}
auto *vjp = getFuncForReference(vjpName);
if (!vjpName.empty()) {
assert(!isDeclaration && "VJP must not be defined in declaration");
assert(vjp && "VJP function must be found if VJP name is not empty");
}
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);

SmallVector<unsigned, 8> parameterAndResultIndices(
rawParameterAndResultIndices.begin(), rawParameterAndResultIndices.end());
assert(parameterAndResultIndices.size() ==
numParameterIndices + numResultIndices &&
"Parameter/result indices count mismatch");
auto *parameterIndices = IndexSubset::get(
MF->getContext(), original->getLoweredFunctionType()->getNumParameters(),
ArrayRef<unsigned>(parameterAndResultIndices)
.take_front(numParameterIndices));
auto *resultIndices = IndexSubset::get(
MF->getContext(), original->getLoweredFunctionType()->getNumResults(),
ArrayRef<unsigned>(parameterAndResultIndices)
.take_back(numResultIndices));

AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig);
auto *diffWitness =
SILMod.lookUpDifferentiabilityWitness({originalName, config});

// Witnesses that we deserialize are always available externally; we never
// want to emit them ourselves.
auto linkage = swift::addExternalToLinkage(*linkageOpt);

// If there is no existing differentiability witness, create one.
if (!diffWitness)
diffWitness = SILDifferentiabilityWitness::createDeclaration(
SILMod, linkage, original, parameterIndices, resultIndices,
derivativeGenSig);

// If the current differentiability witness is merely a declaration, and the
// deserialized witness is a definition, upgrade the current differentiability
// witness to a definition. This can happen in the following situations:
// 1. The witness was just created above.
// 2. The witness started out as a declaration (e.g. the differentiation
// pass emitted a witness for an external function) and now we're loading
// the definition (e.g. an optimization pass asked for the definition and
// we found the definition serialized in this module).
if (diffWitness->isDeclaration() && !isDeclaration)
diffWitness->convertToDefinition(jvp, vjp, isSerialized);

diffWitnessOrOffset.set(diffWitness,
/*isFullyDeserialized*/ diffWitness->isDefinition());
return diffWitness;
}

SILDifferentiabilityWitness *SILDeserializer::lookupDifferentiabilityWitness(
StringRef mangledDiffWitnessKey) {
if (!DifferentiabilityWitnessList)
return nullptr;
auto iter = DifferentiabilityWitnessList->find(mangledDiffWitnessKey);
if (iter == DifferentiabilityWitnessList->end())
return nullptr;
return readDifferentiabilityWitness(*iter);
}

void SILDeserializer::getAllDifferentiabilityWitnesses() {
if (!DifferentiabilityWitnessList)
return;
for (unsigned I = 0, E = DifferentiabilityWitnesses.size(); I < E; ++I)
readDifferentiabilityWitness(I + 1);
}

SILDeserializer::~SILDeserializer() {
// Drop our references to anything we've deserialized.
for (auto &fnEntry : Funcs) {
Expand Down
17 changes: 17 additions & 0 deletions lib/Serialization/DeserializeSIL.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ namespace swift {
MutableArrayRef<ModuleFile::PartiallySerialized<SILProperty *>>
Properties;

std::unique_ptr<SerializedFuncTable> DifferentiabilityWitnessList;
MutableArrayRef<
ModuleFile::PartiallySerialized<SILDifferentiabilityWitness *>>
DifferentiabilityWitnesses;

/// A declaration will only
llvm::DenseMap<NormalProtocolConformance *, SILWitnessTable *>
ConformanceToWitnessTableMap;
Expand Down Expand Up @@ -113,6 +118,9 @@ namespace swift {
SILType getSILType(Type ty, SILValueCategory category,
SILFunction *inContext);

SILDifferentiabilityWitness *
getSILDifferentiabilityWitnessForReference(StringRef mangledKey);

SILFunction *getFuncForReference(StringRef Name, SILType Ty);
SILFunction *getFuncForReference(StringRef Name);
SILVTable *readVTable(serialization::DeclID);
Expand All @@ -129,6 +137,8 @@ namespace swift {
SILDefaultWitnessTable *
readDefaultWitnessTable(serialization::DeclID,
SILDefaultWitnessTable *existingWt);
SILDifferentiabilityWitness *
readDifferentiabilityWitness(serialization::DeclID);

Optional<KeyPathPatternComponent>
readKeyPathComponent(ArrayRef<uint64_t> ListOfValues, unsigned &nextValue);
Expand All @@ -148,6 +158,8 @@ namespace swift {
SILWitnessTable *lookupWitnessTable(SILWitnessTable *wt);
SILDefaultWitnessTable *
lookupDefaultWitnessTable(SILDefaultWitnessTable *wt);
SILDifferentiabilityWitness *
lookupDifferentiabilityWitness(StringRef mangledDiffWitnessKey);

/// Invalidate all cached SILFunctions.
void invalidateFunctionCache();
Expand All @@ -172,6 +184,7 @@ namespace swift {
getAllWitnessTables();
getAllDefaultWitnessTables();
getAllProperties();
getAllDifferentiabilityWitnesses();
}

/// Deserialize all SILFunctions inside the module and add them to SILMod.
Expand All @@ -195,6 +208,10 @@ namespace swift {
/// to SILMod.
void getAllProperties();

/// Deserialize all DifferentiabilityWitnesses inside the module and add
/// them to SILMod.
void getAllDifferentiabilityWitnesses();

SILDeserializer(ModuleFile *MF, SILModule &M,
DeserializationNotificationHandlerSet *callback);

Expand Down
2 changes: 1 addition & 1 deletion lib/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 536; // Clang function types
const uint16_t SWIFTMODULE_VERSION_MINOR = 537; // SIL differentiability witnesses

/// A standard hash seed used for all string hashes in a serialized module.
///
Expand Down
17 changes: 17 additions & 0 deletions lib/Serialization/SILFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ namespace sil_index_block {
SIL_DEFAULT_WITNESS_TABLE_NAMES,
SIL_DEFAULT_WITNESS_TABLE_OFFSETS,
SIL_PROPERTY_OFFSETS,
SIL_DIFFERENTIABILITY_WITNESS_NAMES,
SIL_DIFFERENTIABILITY_WITNESS_OFFSETS,
};

using ListLayout = BCGenericRecordLayout<
Expand Down Expand Up @@ -141,6 +143,7 @@ namespace sil_block {
SIL_WITNESS_CONDITIONAL_CONFORMANCE,
SIL_DEFAULT_WITNESS_TABLE,
SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY,
SIL_DIFFERENTIABILITY_WITNESS,
SIL_INST_WITNESS_METHOD,
SIL_SPECIALIZE_ATTR,
SIL_PROPERTY,
Expand Down Expand Up @@ -250,6 +253,20 @@ namespace sil_block {
DeclIDField
>;

using DifferentiabilityWitnessLayout = BCRecordLayout<
SIL_DIFFERENTIABILITY_WITNESS,
DeclIDField, // Original function name
SILLinkageField, // Linkage
BCFixed<1>, // Is declaration?
BCFixed<1>, // Is serialized?
GenericSignatureIDField, // Derivative function generic signature
DeclIDField, // JVP function name
DeclIDField, // VJP function name
BCVBR<8>, // Number of parameter indices
BCVBR<8>, // Number of result indices
BCArray<ValueIDField> // Parameter and result indices
>;

using SILFunctionLayout =
BCRecordLayout<SIL_FUNCTION, SILLinkageField,
BCFixed<1>, // transparent
Expand Down
Loading