Skip to content

[AutoDiff] declaration-only SILDifferentiabilityWitness #27854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,8 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
// SIL differentiability witnesses
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
"expected '%0' in differentiability witness", (StringRef))
ERROR(sil_diff_witness_serialized_declaration,none,
"differentiability witness declaration should not be serialized", ())

// SIL Coverage Map
ERROR(sil_coverage_func_not_found, none,
Expand Down
71 changes: 41 additions & 30 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,78 +43,89 @@ class SILDifferentiabilityWitness
{
private:
/// The module which contains the differentiability witness.
SILModule &module;
SILModule &Module;
/// The linkage of the differentiability witness.
SILLinkage linkage;
SILLinkage Linkage;
/// The original function.
SILFunction *originalFunction;
SILFunction *OriginalFunction;
/// The autodiff configuration: parameter indices, result indices, derivative
/// generic signature (optional).
AutoDiffConfig config;
AutoDiffConfig Config;
/// The JVP (Jacobian-vector products) derivative function.
SILFunction *jvp;
SILFunction *JVP;
/// The VJP (vector-Jacobian products) derivative function.
SILFunction *vjp;
SILFunction *VJP;
/// Whether or not this differentiability witness is a declaration.
bool IsDeclaration;
/// Whether or not this differentiability witness is serialized, which allows
/// devirtualization from another module.
bool serialized;
bool IsSerialized;
/// The AST `@differentiable` or `@differentiating` attribute from which the
/// differentiability witness is generated. Used for diagnostics.
/// Null if the differentiability witness is parsed from SIL or if it is
/// deserialized.
DeclAttribute *attribute = nullptr;
DeclAttribute *Attribute = nullptr;

SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
SILFunction *originalFunction,
IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature derivativeGenSig,
SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute)
: module(module), linkage(linkage), originalFunction(originalFunction),
config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
jvp(jvp), vjp(vjp), serialized(isSerialized), attribute(attribute) {}
bool isDeclaration, bool isSerialized,
DeclAttribute *attribute)
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
IsSerialized(isSerialized), Attribute(attribute) {}

public:
static SILDifferentiabilityWitness *create(
static SILDifferentiabilityWitness *createDeclaration(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, DeclAttribute *attribute = nullptr);

static SILDifferentiabilityWitness *createDefinition(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute = nullptr);

SILDifferentiabilityWitnessKey getKey() const;
SILModule &getModule() const { return module; }
SILLinkage getLinkage() const { return linkage; }
SILFunction *getOriginalFunction() const { return originalFunction; }
const AutoDiffConfig &getConfig() const { return config; }
SILModule &getModule() const { return Module; }
SILLinkage getLinkage() const { return Linkage; }
SILFunction *getOriginalFunction() const { return OriginalFunction; }
const AutoDiffConfig &getConfig() const { return Config; }
IndexSubset *getParameterIndices() const {
return config.parameterIndices;
return Config.parameterIndices;
}
IndexSubset *getResultIndices() const {
return config.resultIndices;
return Config.resultIndices;
}
GenericSignature getDerivativeGenericSignature() const {
return config.derivativeGenericSignature;
return Config.derivativeGenericSignature;
}
SILFunction *getJVP() const { return jvp; }
SILFunction *getVJP() const { return vjp; }
SILFunction *getJVP() const { return JVP; }
SILFunction *getVJP() const { return VJP; }
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: return jvp;
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
case AutoDiffDerivativeFunctionKind::JVP: return JVP;
case AutoDiffDerivativeFunctionKind::VJP: return VJP;
}
}
void setJVP(SILFunction *jvp) { this->jvp = jvp; }
void setVJP(SILFunction *vjp) { this->vjp = vjp; }
void setJVP(SILFunction *jvp) { JVP = jvp; }
void setVJP(SILFunction *vjp) { VJP = vjp; }
void setDerivative(AutoDiffDerivativeFunctionKind kind,
SILFunction *derivative) {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break;
case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break;
}
}
bool isSerialized() const { return serialized; }
DeclAttribute *getAttribute() const { return attribute; }
bool isDeclaration() const { return IsDeclaration; }
bool isDefinition() const { return !IsDeclaration; }
bool isSerialized() const { return IsSerialized; }
DeclAttribute *getAttribute() const { return Attribute; }

/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;
Expand Down
43 changes: 28 additions & 15 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6933,7 +6933,9 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) {
/// '[' 'parameters' index-subset ']'
/// '[' 'results' index-subset ']'
/// ('[' 'where' derivatve-generic-signature-requirements ']')?
/// sil-function-name ':' sil-type
/// decl-sil-differentiability-witness-body?
///
/// decl-sil-differentiability-witness-body ::=
/// '{'
/// ('jvp' sil-function-name ':' sil-type)?
/// ('vjp' sil-function-name ':' sil-type)?
Expand All @@ -6949,9 +6951,6 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
Optional<SILLinkage> linkage;
if (parseSILLinkage(linkage, P))
return true;
// Default to public linkage.
if (!linkage)
linkage = SILLinkage::Public;

// Parse '[serialized]' flag (optional).
bool isSerialized = false;
Expand Down Expand Up @@ -6986,8 +6985,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
P.diagnose(fnNameLoc, diag::expected_sil_function_type);
return true;
}
fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true);
State.TUState.PotentialZombieFns.insert(fn);
fn = State.getGlobalNameForReference(name, fnType, fnNameLoc);
return false;
};

Expand Down Expand Up @@ -7063,7 +7061,26 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
nullptr);
}

// Parse differentiability witness body.
auto origFnType = originalFn->getLoweredFunctionType();
auto *parameterIndexSet = IndexSubset::get(
P.Context, origFnType->getNumParameters(), parameterIndices);
auto *resultIndexSet = IndexSubset::get(
P.Context, origFnType->getNumResults(), resultIndices);

// If this is just a declaration, create the declaration now and return.
if (!P.Tok.is(tok::l_brace)) {
if (isSerialized) {
P.diagnose(lastLoc, diag::sil_diff_witness_serialized_declaration);
return true;
}

SILDifferentiabilityWitness::createDeclaration(
M, linkage ? *linkage : SILLinkage::DefaultForDeclaration, originalFn,
parameterIndexSet, resultIndexSet, derivativeGenSig);
return false;
}

// This is a definition, so parse differentiability witness body.
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
if (P.Tok.is(tok::l_brace)) {
Expand Down Expand Up @@ -7094,14 +7111,10 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
return true;
}

auto origFnType = originalFn->getLoweredFunctionType();
auto *parameterIndexSet = IndexSubset::get(
P.Context, origFnType->getNumParameters(), parameterIndices);
auto *resultIndexSet = IndexSubset::get(
P.Context, origFnType->getNumResults(), resultIndices);
SILDifferentiabilityWitness::create(
M, *linkage, originalFn, parameterIndexSet, resultIndexSet,
derivativeGenSig, jvp, vjp, isSerialized);
SILDifferentiabilityWitness::createDefinition(
M, linkage ? *linkage : SILLinkage::DefaultForDefinition, originalFn,
parameterIndexSet, resultIndexSet, derivativeGenSig, jvp, vjp,
isSerialized);
return false;
}
// SWIFT_ENABLE_TENSORFLOW END
Expand Down
24 changes: 21 additions & 3 deletions lib/SIL/SILDifferentiabilityWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,31 @@

using namespace swift;

SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, DeclAttribute *attribute) {
auto *diffWitness = new (module) SILDifferentiabilityWitness(
module, linkage, originalFunction, parameterIndices, resultIndices,
derivativeGenSig, /*jvp*/ nullptr, /*vjp*/ nullptr,
/*isDeclaration*/ true, /*isSerialized*/ false, attribute);
// Register the differentiability witness in the module.
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
"Cannot create duplicate differentiability witness in a module");
module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness;
module.getDifferentiabilityWitnessList().push_back(diffWitness);
return diffWitness;
}

SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute) {
auto *diffWitness = new (module) SILDifferentiabilityWitness(
module, linkage, originalFunction, parameterIndices, resultIndices,
derivativeGenSig, jvp, vjp, isSerialized, attribute);
derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
attribute);
// Register the differentiability witness in the module.
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
"Cannot create duplicate differentiability witness in a module");
Expand All @@ -33,6 +50,7 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
return diffWitness;
}


SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
return std::make_pair(originalFunction->getName(), getConfig());
return std::make_pair(getOriginalFunction()->getName(), getConfig());
}
16 changes: 10 additions & 6 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3164,11 +3164,11 @@ void SILDefaultWitnessTable::dump() const {
void SILDifferentiabilityWitness::print(
llvm::raw_ostream &OS, bool verbose) const {
OS << "// differentiability witness for "
<< demangleSymbol(originalFunction->getName()) << '\n';
<< demangleSymbol(getOriginalFunction()->getName()) << '\n';
PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
// sil_differentiability_witness (linkage)?
OS << "sil_differentiability_witness ";
printLinkage(OS, linkage, ForDefinition);
printLinkage(OS, getLinkage(), /*isDefinition*/ isDefinition());
// ([serialized])?
if (isSerialized())
OS << "[serialized] ";
Expand All @@ -3187,7 +3187,7 @@ void SILDifferentiabilityWitness::print(
if (auto derivativeGenSig = getDerivativeGenericSignature()) {
ArrayRef<Requirement> requirements;
SmallVector<Requirement, 4> requirementsScratch;
auto *origGenEnv = originalFunction->getGenericEnvironment();
auto *origGenEnv = getOriginalFunction()->getGenericEnvironment();
if (derivativeGenSig) {
if (origGenEnv) {
requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy(
Expand All @@ -3210,18 +3210,22 @@ void SILDifferentiabilityWitness::print(
}
}
// @original-function-name : $original-sil-type
printSILFunctionNameAndType(OS, originalFunction);
printSILFunctionNameAndType(OS, getOriginalFunction());

if (isDeclaration())
return;

// {
// jvp: @jvp-function-name : $jvp-sil-type
// vjp: @vjp-function-name : $vjp-sil-type
// }
OS << " {\n";
if (jvp) {
if (auto *jvp = getJVP()) {
OS << " jvp: ";
printSILFunctionNameAndType(OS, jvp);
OS << '\n';
}
if (vjp) {
if (auto *vjp = getVJP()) {
OS << " vjp: ";
printSILFunctionNameAndType(OS, vjp);
OS << '\n';
Expand Down
6 changes: 3 additions & 3 deletions lib/SIL/SILVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5387,7 +5387,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
if (!M.getOptions().VerifyAll)
return;
#endif
auto origFnType = originalFunction->getLoweredFunctionType();
auto origFnType = getOriginalFunction()->getLoweredFunctionType();
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = getDerivativeGenericSignature())
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
Expand All @@ -5407,7 +5407,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
else
exit(1);
};
if (jvp) {
if (auto *jvp = getJVP()) {
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
// to accept result indices.
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
Expand All @@ -5417,7 +5417,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
"JVP type does not match expected JVP type");
}
if (vjp) {
if (auto *vjp = getVJP()) {
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
// to result indices.
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ void SILGenModule::emitDifferentiabilityWitness(
// TODO(TF-919): Explore creating serialized differentiability witnesses.
// Currently, differentiability witnesses are never serialized to avoid
// deserialization issues where JVP/VJP functions cannot be found.
auto *diffWitness = SILDifferentiabilityWitness::create(
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction,
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
Expand Down
31 changes: 24 additions & 7 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3407,14 +3407,19 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
(void)kind;

DeclID originalNameId, jvpNameId, vjpNameId;
unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices;
unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
numResultIndices;
GenericSignatureID derivativeGenSigID;
ArrayRef<uint64_t> rawParameterAndResultIndices;

DifferentiabilityWitnessLayout::readRecord(
scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID,
jvpNameId, vjpNameId, numParameterIndices, numResultIndices,
rawParameterAndResultIndices);
scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
numResultIndices, rawParameterAndResultIndices);

if (isDeclaration) {
assert(!isSerialized && "declaration must not be serialized");
}

auto linkage = fromStableSILLinkage(rawLinkage);
assert(linkage && "Expected value linkage for sil_differentiability_witness");
Expand All @@ -3424,11 +3429,15 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
auto *original = getFuncForReference(originalName);
assert(original && "Original function must be found");
auto *jvp = getFuncForReference(jvpName);
if (!jvpName.empty())
if (!jvpName.empty()) {
assert(!isDeclaration && "JVP must not be defined in declaration");
assert(jvp && "JVP function must be found if JVP name is not empty");
}
auto *vjp = getFuncForReference(vjpName);
if (!vjpName.empty())
if (!vjpName.empty()) {
assert(!isDeclaration && "VJP must not be defined in declaration");
assert(vjp && "VJP function must be found if VJP name is not empty");
}
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);

SmallVector<unsigned, 8> parameterAndResultIndices(
Expand All @@ -3446,7 +3455,15 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
ArrayRef<unsigned>(parameterAndResultIndices)
.take_back(numResultIndices));

auto *diffWitness = SILDifferentiabilityWitness::create(
if (isDeclaration) {
auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
SILMod, *linkage, original, parameterIndices, resultIndices,
derivativeGenSig);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ false);
return diffWitness;
}

auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
SILMod, *linkage, original, parameterIndices, resultIndices,
derivativeGenSig, jvp, vjp, isSerialized);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
Expand Down
Loading