Skip to content

Commit 051e6e0

Browse files
author
marcrasi
authored
[AutoDiff] declaration-only SILDifferentiabilityWitness (#27854)
We need declaration-only `SILDifferentiabilityWitness` so that we can refer to differentiability witnesses defined in other modules. Therefore, this PR: * Adds distinct declaration/definition `create` methods for `SILDifferentiabilityWitness`. (credit to @dan-zheng's original work on this) * Handles the distinction in SIL printing & parsing: The definitions have a body in braces, and the declarations don't. * Handles the distinction in serialization: There is a new bit describing whether it's a declaration or definition. * Fixes the "deserialization currently fails if public function bodies are removed so that they are only declarations" problem noted in `test/AutoDiff/sil_differentiability_witness.sil`. * The overall purpose of this PR is to allow us to have decl-only differentiability witnesses referencing decl-only functions, so this fix is important to the overall purpose of this PR. * The change at `fn = State.getGlobalNameForReference(name, fnType, fnNameLoc);` in `ParseSIL.cpp` is what fixes this. * Adds test for a decl differentiability witness.
1 parent 93d8d6c commit 051e6e0

File tree

11 files changed

+180
-67
lines changed

11 files changed

+180
-67
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,8 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
689689
// SIL differentiability witnesses
690690
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691691
"expected '%0' in differentiability witness", (StringRef))
692+
ERROR(sil_diff_witness_serialized_declaration,none,
693+
"differentiability witness declaration should not be serialized", ())
692694

693695
// SIL Coverage Map
694696
ERROR(sil_coverage_func_not_found, none,

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,78 +43,89 @@ class SILDifferentiabilityWitness
4343
{
4444
private:
4545
/// The module which contains the differentiability witness.
46-
SILModule &module;
46+
SILModule &Module;
4747
/// The linkage of the differentiability witness.
48-
SILLinkage linkage;
48+
SILLinkage Linkage;
4949
/// The original function.
50-
SILFunction *originalFunction;
50+
SILFunction *OriginalFunction;
5151
/// The autodiff configuration: parameter indices, result indices, derivative
5252
/// generic signature (optional).
53-
AutoDiffConfig config;
53+
AutoDiffConfig Config;
5454
/// The JVP (Jacobian-vector products) derivative function.
55-
SILFunction *jvp;
55+
SILFunction *JVP;
5656
/// The VJP (vector-Jacobian products) derivative function.
57-
SILFunction *vjp;
57+
SILFunction *VJP;
58+
/// Whether or not this differentiability witness is a declaration.
59+
bool IsDeclaration;
5860
/// Whether or not this differentiability witness is serialized, which allows
5961
/// devirtualization from another module.
60-
bool serialized;
62+
bool IsSerialized;
6163
/// The AST `@differentiable` or `@differentiating` attribute from which the
6264
/// differentiability witness is generated. Used for diagnostics.
6365
/// Null if the differentiability witness is parsed from SIL or if it is
6466
/// deserialized.
65-
DeclAttribute *attribute = nullptr;
67+
DeclAttribute *Attribute = nullptr;
6668

6769
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
6870
SILFunction *originalFunction,
6971
IndexSubset *parameterIndices,
7072
IndexSubset *resultIndices,
7173
GenericSignature derivativeGenSig,
7274
SILFunction *jvp, SILFunction *vjp,
73-
bool isSerialized, DeclAttribute *attribute)
74-
: module(module), linkage(linkage), originalFunction(originalFunction),
75-
config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
76-
jvp(jvp), vjp(vjp), serialized(isSerialized), attribute(attribute) {}
75+
bool isDeclaration, bool isSerialized,
76+
DeclAttribute *attribute)
77+
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
78+
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
79+
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
80+
IsSerialized(isSerialized), Attribute(attribute) {}
7781

7882
public:
79-
static SILDifferentiabilityWitness *create(
83+
static SILDifferentiabilityWitness *createDeclaration(
84+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
85+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
86+
GenericSignature derivativeGenSig, DeclAttribute *attribute = nullptr);
87+
88+
static SILDifferentiabilityWitness *createDefinition(
8089
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
8190
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8291
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
8392
bool isSerialized, DeclAttribute *attribute = nullptr);
8493

8594
SILDifferentiabilityWitnessKey getKey() const;
86-
SILModule &getModule() const { return module; }
87-
SILLinkage getLinkage() const { return linkage; }
88-
SILFunction *getOriginalFunction() const { return originalFunction; }
89-
const AutoDiffConfig &getConfig() const { return config; }
95+
SILModule &getModule() const { return Module; }
96+
SILLinkage getLinkage() const { return Linkage; }
97+
SILFunction *getOriginalFunction() const { return OriginalFunction; }
98+
const AutoDiffConfig &getConfig() const { return Config; }
9099
IndexSubset *getParameterIndices() const {
91-
return config.parameterIndices;
100+
return Config.parameterIndices;
92101
}
93102
IndexSubset *getResultIndices() const {
94-
return config.resultIndices;
103+
return Config.resultIndices;
95104
}
96105
GenericSignature getDerivativeGenericSignature() const {
97-
return config.derivativeGenericSignature;
106+
return Config.derivativeGenericSignature;
98107
}
99-
SILFunction *getJVP() const { return jvp; }
100-
SILFunction *getVJP() const { return vjp; }
108+
SILFunction *getJVP() const { return JVP; }
109+
SILFunction *getVJP() const { return VJP; }
101110
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
102111
switch (kind) {
103-
case AutoDiffDerivativeFunctionKind::JVP: return jvp;
104-
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
112+
case AutoDiffDerivativeFunctionKind::JVP: return JVP;
113+
case AutoDiffDerivativeFunctionKind::VJP: return VJP;
105114
}
106115
}
107-
void setJVP(SILFunction *jvp) { this->jvp = jvp; }
108-
void setVJP(SILFunction *vjp) { this->vjp = vjp; }
116+
void setJVP(SILFunction *jvp) { JVP = jvp; }
117+
void setVJP(SILFunction *vjp) { VJP = vjp; }
109118
void setDerivative(AutoDiffDerivativeFunctionKind kind,
110119
SILFunction *derivative) {
111120
switch (kind) {
112-
case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
113-
case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
121+
case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break;
122+
case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break;
114123
}
115124
}
116-
bool isSerialized() const { return serialized; }
117-
DeclAttribute *getAttribute() const { return attribute; }
125+
bool isDeclaration() const { return IsDeclaration; }
126+
bool isDefinition() const { return !IsDeclaration; }
127+
bool isSerialized() const { return IsSerialized; }
128+
DeclAttribute *getAttribute() const { return Attribute; }
118129

119130
/// Verify that the differentiability witness is well-formed.
120131
void verify(const SILModule &module) const;

lib/ParseSIL/ParseSIL.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6933,7 +6933,9 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) {
69336933
/// '[' 'parameters' index-subset ']'
69346934
/// '[' 'results' index-subset ']'
69356935
/// ('[' 'where' derivatve-generic-signature-requirements ']')?
6936-
/// sil-function-name ':' sil-type
6936+
/// decl-sil-differentiability-witness-body?
6937+
///
6938+
/// decl-sil-differentiability-witness-body ::=
69376939
/// '{'
69386940
/// ('jvp' sil-function-name ':' sil-type)?
69396941
/// ('vjp' sil-function-name ':' sil-type)?
@@ -6949,9 +6951,6 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69496951
Optional<SILLinkage> linkage;
69506952
if (parseSILLinkage(linkage, P))
69516953
return true;
6952-
// Default to public linkage.
6953-
if (!linkage)
6954-
linkage = SILLinkage::Public;
69556954

69566955
// Parse '[serialized]' flag (optional).
69576956
bool isSerialized = false;
@@ -6986,8 +6985,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69866985
P.diagnose(fnNameLoc, diag::expected_sil_function_type);
69876986
return true;
69886987
}
6989-
fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true);
6990-
State.TUState.PotentialZombieFns.insert(fn);
6988+
fn = State.getGlobalNameForReference(name, fnType, fnNameLoc);
69916989
return false;
69926990
};
69936991

@@ -7063,7 +7061,26 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
70637061
nullptr);
70647062
}
70657063

7066-
// Parse differentiability witness body.
7064+
auto origFnType = originalFn->getLoweredFunctionType();
7065+
auto *parameterIndexSet = IndexSubset::get(
7066+
P.Context, origFnType->getNumParameters(), parameterIndices);
7067+
auto *resultIndexSet = IndexSubset::get(
7068+
P.Context, origFnType->getNumResults(), resultIndices);
7069+
7070+
// If this is just a declaration, create the declaration now and return.
7071+
if (!P.Tok.is(tok::l_brace)) {
7072+
if (isSerialized) {
7073+
P.diagnose(lastLoc, diag::sil_diff_witness_serialized_declaration);
7074+
return true;
7075+
}
7076+
7077+
SILDifferentiabilityWitness::createDeclaration(
7078+
M, linkage ? *linkage : SILLinkage::DefaultForDeclaration, originalFn,
7079+
parameterIndexSet, resultIndexSet, derivativeGenSig);
7080+
return false;
7081+
}
7082+
7083+
// This is a definition, so parse differentiability witness body.
70677084
SILFunction *jvp = nullptr;
70687085
SILFunction *vjp = nullptr;
70697086
if (P.Tok.is(tok::l_brace)) {
@@ -7094,14 +7111,10 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
70947111
return true;
70957112
}
70967113

7097-
auto origFnType = originalFn->getLoweredFunctionType();
7098-
auto *parameterIndexSet = IndexSubset::get(
7099-
P.Context, origFnType->getNumParameters(), parameterIndices);
7100-
auto *resultIndexSet = IndexSubset::get(
7101-
P.Context, origFnType->getNumResults(), resultIndices);
7102-
SILDifferentiabilityWitness::create(
7103-
M, *linkage, originalFn, parameterIndexSet, resultIndexSet,
7104-
derivativeGenSig, jvp, vjp, isSerialized);
7114+
SILDifferentiabilityWitness::createDefinition(
7115+
M, linkage ? *linkage : SILLinkage::DefaultForDefinition, originalFn,
7116+
parameterIndexSet, resultIndexSet, derivativeGenSig, jvp, vjp,
7117+
isSerialized);
71057118
return false;
71067119
}
71077120
// SWIFT_ENABLE_TENSORFLOW END

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,31 @@
1717

1818
using namespace swift;
1919

20-
SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
20+
SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
21+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
22+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
23+
GenericSignature derivativeGenSig, DeclAttribute *attribute) {
24+
auto *diffWitness = new (module) SILDifferentiabilityWitness(
25+
module, linkage, originalFunction, parameterIndices, resultIndices,
26+
derivativeGenSig, /*jvp*/ nullptr, /*vjp*/ nullptr,
27+
/*isDeclaration*/ true, /*isSerialized*/ false, attribute);
28+
// Register the differentiability witness in the module.
29+
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
30+
"Cannot create duplicate differentiability witness in a module");
31+
module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness;
32+
module.getDifferentiabilityWitnessList().push_back(diffWitness);
33+
return diffWitness;
34+
}
35+
36+
SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
2137
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
2238
IndexSubset *parameterIndices, IndexSubset *resultIndices,
2339
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
2440
bool isSerialized, DeclAttribute *attribute) {
2541
auto *diffWitness = new (module) SILDifferentiabilityWitness(
2642
module, linkage, originalFunction, parameterIndices, resultIndices,
27-
derivativeGenSig, jvp, vjp, isSerialized, attribute);
43+
derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
44+
attribute);
2845
// Register the differentiability witness in the module.
2946
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
3047
"Cannot create duplicate differentiability witness in a module");
@@ -33,6 +50,7 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
3350
return diffWitness;
3451
}
3552

53+
3654
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
37-
return std::make_pair(originalFunction->getName(), getConfig());
55+
return std::make_pair(getOriginalFunction()->getName(), getConfig());
3856
}

lib/SIL/SILPrinter.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3164,11 +3164,11 @@ void SILDefaultWitnessTable::dump() const {
31643164
void SILDifferentiabilityWitness::print(
31653165
llvm::raw_ostream &OS, bool verbose) const {
31663166
OS << "// differentiability witness for "
3167-
<< demangleSymbol(originalFunction->getName()) << '\n';
3167+
<< demangleSymbol(getOriginalFunction()->getName()) << '\n';
31683168
PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
31693169
// sil_differentiability_witness (linkage)?
31703170
OS << "sil_differentiability_witness ";
3171-
printLinkage(OS, linkage, ForDefinition);
3171+
printLinkage(OS, getLinkage(), /*isDefinition*/ isDefinition());
31723172
// ([serialized])?
31733173
if (isSerialized())
31743174
OS << "[serialized] ";
@@ -3187,7 +3187,7 @@ void SILDifferentiabilityWitness::print(
31873187
if (auto derivativeGenSig = getDerivativeGenericSignature()) {
31883188
ArrayRef<Requirement> requirements;
31893189
SmallVector<Requirement, 4> requirementsScratch;
3190-
auto *origGenEnv = originalFunction->getGenericEnvironment();
3190+
auto *origGenEnv = getOriginalFunction()->getGenericEnvironment();
31913191
if (derivativeGenSig) {
31923192
if (origGenEnv) {
31933193
requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy(
@@ -3210,18 +3210,22 @@ void SILDifferentiabilityWitness::print(
32103210
}
32113211
}
32123212
// @original-function-name : $original-sil-type
3213-
printSILFunctionNameAndType(OS, originalFunction);
3213+
printSILFunctionNameAndType(OS, getOriginalFunction());
3214+
3215+
if (isDeclaration())
3216+
return;
3217+
32143218
// {
32153219
// jvp: @jvp-function-name : $jvp-sil-type
32163220
// vjp: @vjp-function-name : $vjp-sil-type
32173221
// }
32183222
OS << " {\n";
3219-
if (jvp) {
3223+
if (auto *jvp = getJVP()) {
32203224
OS << " jvp: ";
32213225
printSILFunctionNameAndType(OS, jvp);
32223226
OS << '\n';
32233227
}
3224-
if (vjp) {
3228+
if (auto *vjp = getVJP()) {
32253229
OS << " vjp: ";
32263230
printSILFunctionNameAndType(OS, vjp);
32273231
OS << '\n';

lib/SIL/SILVerifier.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5387,7 +5387,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
53875387
if (!M.getOptions().VerifyAll)
53885388
return;
53895389
#endif
5390-
auto origFnType = originalFunction->getLoweredFunctionType();
5390+
auto origFnType = getOriginalFunction()->getLoweredFunctionType();
53915391
CanGenericSignature derivativeCanGenSig;
53925392
if (auto derivativeGenSig = getDerivativeGenericSignature())
53935393
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
@@ -5407,7 +5407,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
54075407
else
54085408
exit(1);
54095409
};
5410-
if (jvp) {
5410+
if (auto *jvp = getJVP()) {
54115411
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
54125412
// to accept result indices.
54135413
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
@@ -5417,7 +5417,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
54175417
requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
54185418
"JVP type does not match expected JVP type");
54195419
}
5420-
if (vjp) {
5420+
if (auto *vjp = getVJP()) {
54215421
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
54225422
// to result indices.
54235423
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(

lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ void SILGenModule::emitDifferentiabilityWitness(
832832
// TODO(TF-919): Explore creating serialized differentiability witnesses.
833833
// Currently, differentiability witnesses are never serialized to avoid
834834
// deserialization issues where JVP/VJP functions cannot be found.
835-
auto *diffWitness = SILDifferentiabilityWitness::create(
835+
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
836836
M, originalFunction->getLinkage(), originalFunction,
837837
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
838838
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);

lib/Serialization/DeserializeSIL.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3407,14 +3407,19 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
34073407
(void)kind;
34083408

34093409
DeclID originalNameId, jvpNameId, vjpNameId;
3410-
unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices;
3410+
unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
3411+
numResultIndices;
34113412
GenericSignatureID derivativeGenSigID;
34123413
ArrayRef<uint64_t> rawParameterAndResultIndices;
34133414

34143415
DifferentiabilityWitnessLayout::readRecord(
3415-
scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID,
3416-
jvpNameId, vjpNameId, numParameterIndices, numResultIndices,
3417-
rawParameterAndResultIndices);
3416+
scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
3417+
derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
3418+
numResultIndices, rawParameterAndResultIndices);
3419+
3420+
if (isDeclaration) {
3421+
assert(!isSerialized && "declaration must not be serialized");
3422+
}
34183423

34193424
auto linkage = fromStableSILLinkage(rawLinkage);
34203425
assert(linkage && "Expected value linkage for sil_differentiability_witness");
@@ -3424,11 +3429,15 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
34243429
auto *original = getFuncForReference(originalName);
34253430
assert(original && "Original function must be found");
34263431
auto *jvp = getFuncForReference(jvpName);
3427-
if (!jvpName.empty())
3432+
if (!jvpName.empty()) {
3433+
assert(!isDeclaration && "JVP must not be defined in declaration");
34283434
assert(jvp && "JVP function must be found if JVP name is not empty");
3435+
}
34293436
auto *vjp = getFuncForReference(vjpName);
3430-
if (!vjpName.empty())
3437+
if (!vjpName.empty()) {
3438+
assert(!isDeclaration && "VJP must not be defined in declaration");
34313439
assert(vjp && "VJP function must be found if VJP name is not empty");
3440+
}
34323441
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);
34333442

34343443
SmallVector<unsigned, 8> parameterAndResultIndices(
@@ -3446,7 +3455,15 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
34463455
ArrayRef<unsigned>(parameterAndResultIndices)
34473456
.take_back(numResultIndices));
34483457

3449-
auto *diffWitness = SILDifferentiabilityWitness::create(
3458+
if (isDeclaration) {
3459+
auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
3460+
SILMod, *linkage, original, parameterIndices, resultIndices,
3461+
derivativeGenSig);
3462+
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ false);
3463+
return diffWitness;
3464+
}
3465+
3466+
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
34503467
SILMod, *linkage, original, parameterIndices, resultIndices,
34513468
derivativeGenSig, jvp, vjp, isSerialized);
34523469
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);

0 commit comments

Comments
 (0)