Skip to content

[AutoDiff] SILGen differentiability witnesses. #27652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3085,6 +3085,51 @@ void SILDefaultWitnessTable::dump() const {
print(llvm::errs());
}

// TODO(TF-893): Use this helper to dedupe the same logic in
// `SILFunction::print`.
static void printSILFunctionNameAndType(
llvm::raw_ostream &OS, SILFunction *function) {
function->printName(OS);
OS << " : $";
llvm::DenseMap<CanType, Identifier> Aliases;
llvm::DenseSet<Identifier> UsedNames;
auto sig = function->getLoweredFunctionType()->getGenericSignature();
auto *env = function->getGenericEnvironment();
if (sig && env) {
llvm::SmallString<16> disambiguatedNameBuf;
unsigned disambiguatedNameCounter = 1;
for (auto *paramTy : sig->getGenericParams()) {
auto sugaredTy = env->getSugaredType(paramTy);
Identifier name = sugaredTy->getName();
while (!UsedNames.insert(name).second) {
disambiguatedNameBuf.clear();
{
llvm::raw_svector_ostream names(disambiguatedNameBuf);
names << sugaredTy->getName() << disambiguatedNameCounter++;
}
name = function->getASTContext().getIdentifier(disambiguatedNameBuf);
}
if (name != sugaredTy->getName()) {
Aliases[paramTy->getCanonicalType()] = name;

// Also for the archetype
auto archetypeTy = env->mapTypeIntoContext(paramTy)
->getAs<ArchetypeType>();
if (archetypeTy)
Aliases[archetypeTy->getCanonicalType()] = name;
}
}
}

{
PrintOptions withGenericEnvironment = PrintOptions::printSIL();
withGenericEnvironment.GenericEnv = env;
withGenericEnvironment.AlternativeTypeNames =
Aliases.empty() ? nullptr : &Aliases;
function->getLoweredFunctionType()->print(OS, withGenericEnvironment);
}
}

// SWIFT_ENABLE_TENSORFLOW
void SILDifferentiabilityWitness::print(
llvm::raw_ostream &OS, bool verbose) const {
Expand All @@ -3107,7 +3152,7 @@ void SILDifferentiabilityWitness::print(
interleave(getResultIndices()->getIndices(),
[&](unsigned index) { OS << index; },
[&] { OS << ' '; });
OS << ']';
OS << "] ";
// ([where ...])?
if (auto *derivativeGenSig = getDerivativeGenericSignature()) {
ArrayRef<Requirement> requirements;
Expand All @@ -3123,28 +3168,34 @@ void SILDifferentiabilityWitness::print(
}
}
if (!requirements.empty()) {
OS << " [where ";
OS << "[where ";
auto subPrinter = PrintOptions::printSIL();
subPrinter.GenericEnv = origGenEnv;
interleave(requirements,
[&](Requirement req) {
req.print(OS, subPrinter);
},
[&] { OS << ", "; });
OS << ']';
OS << "] ";
}
}
// @original-function-name : $original-sil-type
OS << " @" << originalFunction->getName() << " : "
<< originalFunction->getLoweredType();
printSILFunctionNameAndType(OS, originalFunction);
// {
// jvp: @jvp-function-name : $jvp-sil-type
// vjp: @vjp-function-name : $vjp-sil-type
// }
OS << " {\n";
if (jvp)
OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n';
if (vjp)
OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n';
if (jvp) {
OS << " jvp: ";
printSILFunctionNameAndType(OS, jvp);
OS << '\n';
}
if (vjp) {
OS << " vjp: ";
printSILFunctionNameAndType(OS, vjp);
OS << '\n';
}
OS << "}\n\n";
}

Expand Down
198 changes: 123 additions & 75 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,87 +752,135 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
F->print(llvm::dbgs()));

// SWIFT_ENABLE_TENSORFLOW
// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
if (constant.hasDecl() && constant.getAbstractFunctionDecl()) {
// Visit `@differentiable` attributes and generate SIL differentiability
// witnesses.
// TODO(TF-835): Visit `@differentiating` attributes when type-checking no
// longer generates implicit `@differentiable` attributes. See TF-835 for
// replacement code.
// Skip if the SILDeclRef is a:
// - Default argument generator function.
// - Thunk.
if (constant.hasDecl() && constant.getAbstractFunctionDecl() &&
constant.kind != SILDeclRef::Kind::DefaultArgGenerator &&
!constant.isThunk()) {
auto *AFD = constant.getAbstractFunctionDecl();
auto origFnType = AFD->getInterfaceType()->castTo<AnyFunctionType>();
auto origSilFnType = F->getLoweredFunctionType();
// Jointly iterate over AST `@differentiable` attributes and SIL
// `[differentiable]` attributes.
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
auto silDiffAttrs = F->getDifferentiableAttrs();
for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) {
auto *diffAttr = const_cast<DifferentiableAttr *>(std::get<0>(pair));
auto *silDiffAttr = std::get<1>(pair);
// Compute lowered parameter indices.
auto *paramIndices = diffAttr->getParameterIndices();
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
paramIndices, origFnType);
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
assert(silDiffAttr->getIndices() == indices &&
"Expected matching @differentiable and [differentiable] indices");

auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule());
auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source,
AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance);
auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source,
AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance);

// Self reordering is necessary if wrt at least two parameters, including
// self.
auto shouldReorderSelf = [&]() {
if (!F->hasSelfParam())
return false;
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();

// Thunk JVP method, if it is defined.
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
SILFunction *jvpThunk;
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP,
reorderSelf);
} else {
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(), AFD->getASTContext());
jvpThunk = getOrCreateAutoDiffThunk(
constant.asAutoDiffDerivativeFunction(id), jvpFn,
expectedJVPType);
}
silDiffAttr->setJVPName(jvpThunk->getName());
}
// Thunk VJP method, if it is defined.
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
SILFunction *vjpThunk;
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP,
reorderSelf);
} else {
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(), AFD->getASTContext());
vjpThunk = getOrCreateAutoDiffThunk(
constant.asAutoDiffDerivativeFunction(id), vjpFn,
expectedVJPType);
}
silDiffAttr->setVJPName(vjpThunk->getName());
}
// Visit all `@differentiable` attributes.
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
if (auto *jvpDecl = diffAttr->getJVPFunction())
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config{diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature()};
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
}
}
F->verify();
}

void SILGenModule::emitDifferentiabilityWitness(
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) {
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
auto origSilFnType = originalFunction->getLoweredFunctionType();
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
config.parameterIndices, origFnType);
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
// parameters corresponding to captured variables. These parameters do not
// appear in the type of `origFnType`.
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
// take `CaptureInfo` into account.
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
loweredParamIndices = loweredParamIndices->extendingCapacity(
getASTContext(), origSilFnType->getNumParameters());
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);

// Self reordering thunk is necessary if wrt at least two parameters,
// including self.
auto shouldReorderSelf = [&]() {
if (!originalFunction->hasSelfParam())
return false;
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();

CanGenericSignature derivativeCanGenSig;
if (auto *derivativeGenSig = config.derivativeGenericSignature)
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
// TODO(TF-835): Use simpler derivative generic signature logic below when
// type-checking no longer generates implicit `@differentiable` attributes.
// See TF-835 for replacement code.
if (jvp) {
auto jvpCanGenSig = jvp->getLoweredFunctionType()->getGenericSignature();
if (!derivativeCanGenSig && jvpCanGenSig)
derivativeCanGenSig = jvpCanGenSig;
assert(derivativeCanGenSig == jvpCanGenSig);
}
if (vjp) {
auto vjpCanGenSig = vjp->getLoweredFunctionType()->getGenericSignature();
if (!derivativeCanGenSig && vjpCanGenSig)
derivativeCanGenSig = vjpCanGenSig;
assert(derivativeCanGenSig == vjpCanGenSig);
}
// Create new SIL differentiability witness.
// Witness JVP and VJP are set below.
// TODO(TF-919): Explore creating serialized differentiability witnesses.
// Currently, differentiability witnesses are never serialized to avoid
// deserialization issues where JVP/VJP functions cannot be found.
auto *diffWitness = SILDifferentiabilityWitness::create(
M, originalFunction->getLinkage(), originalFunction,
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);

// Set derivative function in differentiability witness.
auto setDerivativeInDifferentiabilityWitness =
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
auto expectedDerivativeType =
origSilFnType->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source, kind, Types,
LookUpConformanceInModule(M.getSwiftModule()));
// Thunk derivative function.
SILFunction *derivativeThunk;
if (reorderSelf ||
derivative->getLoweredFunctionType() != expectedDerivativeType) {
derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
originalFunction, indices, derivative, kind, reorderSelf);
} else {
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
// the AST-level parameter indices, not the SIL-level ones.
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
kind, config.parameterIndices, getASTContext());
derivativeThunk = getOrCreateAutoDiffThunk(
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
expectedDerivativeType);
}
// Check for existing same derivative.
// TODO(TF-835): Remove condition below and simplify assertion to
// `!diffWitness->getDerivative(kind)` after `@differentiating` attribute
// type-checking no longer generates implicit `@differentiable` attributes.
auto *existingDerivative = diffWitness->getDerivative(kind);
if (existingDerivative && existingDerivative == derivativeThunk)
return;
assert(!existingDerivative &&
"SIL differentiability witness already has a different existing "
"derivative");
diffWitness->setDerivative(kind, derivativeThunk);
};
if (jvp)
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP,
jvp);
if (vjp)
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP,
vjp);
}

void SILGenModule::
emitMarkFunctionEscapeForTopLevelCodeGlobals(SILLocation loc,
const CaptureInfo &captureInfo) {
Expand Down
10 changes: 10 additions & 0 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// Emit the self-conformance witness table for a protocol.
void emitSelfConformanceWitnessTable(ProtocolDecl *protocol);

// SWIFT_ENABLE_TENSORFLOW
/// Emit the differentiability witness for the given original function
/// declaration and SIL function, autodiff configuration, and JVP and VJP
/// functions (null if undefined).
void emitDifferentiabilityWitness(AbstractFunctionDecl *originalAFD,
SILFunction *originalFunction,
const AutoDiffConfig &config,
SILFunction *jvp, SILFunction *vjp);
// SWIFT_ENABLE_TENSORFLOW END

/// Emit the lazy initializer function for a global pattern binding
/// declaration.
SILFunction *emitLazyGlobalInitializer(StringRef funcName,
Expand Down
3 changes: 3 additions & 0 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3016,6 +3016,9 @@ void SILDeserializer::readWitnessTableEntries(
// Another record means the end of this WitnessTable.
while (kind != SIL_WITNESS_TABLE &&
kind != SIL_DEFAULT_WITNESS_TABLE &&
// SWIFT_ENABLE_TENSORFLOW
kind != SIL_DIFFERENTIABILITY_WITNESS &&
// SWIFT_ENABLE_TENSORFLOW END
kind != SIL_FUNCTION) {
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
witnessEntries.push_back(SILDefaultWitnessTable::Entry());
Expand Down
2 changes: 2 additions & 0 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ void Serializer::writeBlockInfoBlock() {
BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION);
BLOCK_RECORD(sil_block, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT);
BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION_EXTRACT);
BLOCK_RECORD(sil_block, SIL_DIFFERENTIABILITY_WITNESS);
// SWIFT_ENABLE_TENSORFLOW END

// These layouts can exist in both decl blocks and sil blocks.
Expand Down Expand Up @@ -829,6 +830,7 @@ void Serializer::writeBlockInfoBlock() {
BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_OFFSETS);
BLOCK_RECORD(sil_index_block, SIL_PROPERTY_OFFSETS);
// SWIFT_ENABLE_TENSORFLOW
BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_NAMES);
BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS);
// SWIFT_ENABLE_TENSORFLOW END

Expand Down
17 changes: 6 additions & 11 deletions lib/Serialization/SerializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2322,7 +2322,7 @@ void SILSerializer::writeIndexTables() {
}

// SWIFT_ENABLE_TENSORFLOW
if (!DifferentiabilityWitnessOffset.empty()) {
if (!DifferentiabilityWitnessList.empty()) {
writeIndexTable(S, List,
sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES,
DifferentiabilityWitnessList);
Expand Down Expand Up @@ -2542,17 +2542,12 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) {
DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo());

auto *original = dw.getOriginalFunction();
addReferencedSILFunction(original, /*DeclOnly*/ true);
IdentifierID jvpID = 0;
IdentifierID vjpID = 0;
if (auto *jvp = dw.getJVP()) {
addReferencedSILFunction(jvp, /*DeclOnly*/ true);
jvpID = S.addUniquedStringRef(jvp->getName());
}
if (auto *vjp = dw.getVJP()) {
addReferencedSILFunction(vjp, /*DeclOnly*/ true);
vjpID = S.addUniquedStringRef(vjp->getName());
}
if (auto *jvp = dw.getJVP())
jvpID = addSILFunctionRef(jvp);
if (auto *vjp = dw.getVJP())
vjpID = addSILFunctionRef(vjp);
SmallVector<unsigned, 8> parameterAndResultIndices(
dw.getParameterIndices()->begin(), dw.getParameterIndices()->end());
parameterAndResultIndices.append(dw.getResultIndices()->begin(),
Expand All @@ -2569,7 +2564,7 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) {

DifferentiabilityWitnessLayout::emitRecord(
Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code],
S.addUniquedStringRef(original->getName()),
addSILFunctionRef(original),
toStableSILLinkage(dw.getLinkage()),
dw.isSerialized(),
S.addGenericSignatureRef(dw.getDerivativeGenericSignature()),
Expand Down
Loading