Skip to content

Commit 41015c3

Browse files
author
marcrasi
authored
[AutoDiff] DifferentiabilityWitnessFunctionInst stores pointer to witness (#27919)
We want the `DifferentiabilityWitnessFunctionInst` to store a pointer to the witness because this makes the instruction data structure smaller and simpler, and makes misuse (e.g. referencing a witness that's not declared in the SIL module) harder. This PR mostly preserves the existing SIL syntax & rules. The changes are: * `sil_differentiability_witness` declarations are now required to come before `differentiability_witness_function` instructions that reference them, so that we don't have to add "forward declaration" logic for them. * `sil_differentiability_witness` constraints now look like `<τ_0_0 where τ_0_0 : _Differentiable>` instead of `[where T : _Differentiable]`. This allows me to share parsing logic between the `sil_differentiability_witness` constraints and the `differentiability_witness_function`, which simplifies the code and also avoids a problem [1]. This PR changes the `DifferentiabilityWitnessFunctionInst` serialization format to store the witness mangled name rather than the original function name plus indices and generic constraints. Detailed outline of changes: * Datastructure & utility changes: * `DifferentiabilityWitnessFunctionInst` now contains a `SILDifferentiabilityWitness*` instead of a pointer to original function, indices, and generic signature. (credit to @dan-zheng, I just copied these changes from his commit) * Added functions in `SILModule` for looking up `SILDifferentiabilityWitness*` because parsing and deserialization need to find the right `SILDifferentiabilityWitness*` to put in the inst. * SIL printing changes: * Print `sil_differentiability_witness` declarations before function declarations * Print `<T where T ...>` instead of `[where T ...]` in `sil_differentiability_witness` declarations. * SIL parsing changes: * Factor out a common `parseSILDifferentiabilityWitnessConfigAndFunction`, and use this for parsing both `sil_differentiability_witness` and `differentiability_witness_function`. This simplifies the code and avoids the problem [1]. * When parsing a `differentiability_witness_function`, look for the `SILDifferentiabilityWitness*` in the module. * Bugfix in ParseStmt so that it can parse `sil_differentiability_witness` declarations that come at the beginning of the file. * Serialization changes: * Deleted `SILInstDifferentiabilityWitnessFunctionLayout` from the serialization format, because the `DifferentiabilityWitnessFunctionInst` now fits into the `SILOneOperandLayout`. * Changed serialization to serialize `DifferentiabilityWitnessFunctionInst` a `SILOneOperandLayout` with the mangled witness name. * Changed deserialization to deserialize `DifferentiabilityWitnessFunctionInst` by reading the name and deserializing the corresponding witness. * Changed deserialization to treat declaration-only witnesses as "fully deserialized". Otherwise it tries to deserialize them twice and the module doesn't like having duplicate witnesses. [1] The avoided problem: The previous code for parsing `sil_differentiability_witness` constraints of the form `[where T : _Differentiable]` fails to handle the case where the original function is declaration-only, because the code assumes that there is a generic environment, and declaration-only functions do not have a generic environment. Rather than fixing this code, it seemed better to delete it and use the existing code that parses `differentiability_witness_function` constraints, which does not suffer from this problem.
1 parent 54bdad6 commit 41015c3

19 files changed

+294
-345
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,8 @@ ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691691
"expected '%0' in differentiability witness", (StringRef))
692692
ERROR(sil_diff_witness_serialized_declaration,none,
693693
"differentiability witness declaration should not be serialized", ())
694+
ERROR(sil_diff_witness_undefined,PointsToFirstBadToken,
695+
"reference to undefined differentiability witness", ())
694696

695697
// SIL Coverage Map
696698
ERROR(sil_coverage_func_not_found, none,

include/swift/SIL/SILBuilder.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -551,13 +551,10 @@ class SILBuilder {
551551

552552
DifferentiabilityWitnessFunctionInst *
553553
createDifferentiabilityWitnessFunction(
554-
SILLocation Loc, SILFunction *OriginalFunction,
555-
DifferentiabilityWitnessFunctionKind WitnessKind,
556-
IndexSubset *ParameterIndices, IndexSubset *ResultIndices,
557-
GenericSignature WitnessGenericSignature) {
554+
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,
555+
SILDifferentiabilityWitness *Witness) {
558556
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
559-
getModule(), getSILDebugLocation(Loc), OriginalFunction, WitnessKind,
560-
ParameterIndices, ResultIndices, WitnessGenericSignature.getPointer()));
557+
getModule(), getSILDebugLocation(Loc), WitnessKind, Witness));
561558
}
562559
// SWIFT_ENABLE_TENSORFLOW END
563560

include/swift/SIL/SILCloner.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,9 +1018,8 @@ void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
10181018
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
10191019
recordClonedInstruction(
10201020
Inst, getBuilder().createDifferentiabilityWitnessFunction(
1021-
getOpLocation(Inst->getLoc()), Inst->getOriginalFunction(),
1022-
Inst->getWitnessKind(), Inst->getParameterIndices(),
1023-
Inst->getResultIndices(), Inst->getWitnessGenericSignature()));
1021+
getOpLocation(Inst->getLoc()), Inst->getWitnessKind(),
1022+
Inst->getWitness()));
10241023
}
10251024
// SWIFT_ENABLE_TENSORFLOW END
10261025

include/swift/SIL/SILInstruction.h

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class SILBasicBlock;
6161
class SILBuilder;
6262
class SILDebugLocation;
6363
class SILDebugScope;
64+
// SWIFT_ENABLE_TENSORFLOW
65+
class SILDifferentiabilityWitness;
66+
// SWIFT_ENABLE_TENSORFLOW_END
6467
class SILFunction;
6568
class SILGlobalVariable;
6669
class SILInstructionResultArray;
@@ -8038,43 +8041,31 @@ class DifferentiabilityWitnessFunctionInst
80388041
SingleValueInstruction> {
80398042
private:
80408043
friend SILBuilder;
8041-
/// The original function.
8042-
SILFunction *originalFunction;
80438044
/// The differentiability witness function kind.
80448045
DifferentiabilityWitnessFunctionKind witnessKind;
8045-
/// The autodiff config: parameter indices, result indices, and witness
8046-
/// derivative signature.
8047-
AutoDiffConfig config;
8046+
/// The referenced SIL differentiability witness.
8047+
SILDifferentiabilityWitness *witness;
80488048

80498049
static SILType getDifferentiabilityWitnessType(
8050-
SILModule &module, SILFunction *originalFunction,
8050+
SILModule &module,
80518051
DifferentiabilityWitnessFunctionKind witnessKind,
8052-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8053-
GenericSignature witnessGenericSignature);
8052+
SILDifferentiabilityWitness *witness);
80548053

80558054
public:
80568055
DifferentiabilityWitnessFunctionInst(
8057-
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
8056+
SILModule &module, SILDebugLocation loc,
80588057
DifferentiabilityWitnessFunctionKind witnessKind,
8059-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8060-
GenericSignature witnessGenericSignature);
8058+
SILDifferentiabilityWitness *witness);
80618059

80628060
static DifferentiabilityWitnessFunctionInst *create(
8063-
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
8061+
SILModule &module, SILDebugLocation loc,
80648062
DifferentiabilityWitnessFunctionKind witnessKind,
8065-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8066-
GenericSignature witnessGenericSignature);
8063+
SILDifferentiabilityWitness *witness);
80678064

80688065
DifferentiabilityWitnessFunctionKind getWitnessKind() const {
80698066
return witnessKind;
80708067
}
8071-
SILFunction *getOriginalFunction() const { return originalFunction; }
8072-
AutoDiffConfig const &getConfig() const { return config; }
8073-
IndexSubset *getParameterIndices() const { return config.parameterIndices; }
8074-
IndexSubset *getResultIndices() const { return config.resultIndices; }
8075-
GenericSignature getWitnessGenericSignature() const {
8076-
return config.derivativeGenericSignature;
8077-
}
8068+
SILDifferentiabilityWitness *getWitness() const { return witness; }
80788069

80798070
ArrayRef<Operand> getAllOperands() const { return {}; }
80808071
MutableArrayRef<Operand> getAllOperands() { return {}; }

include/swift/SIL/SILModule.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,8 @@ class SILModule {
204204
DefaultWitnessTableListType defaultWitnessTables;
205205

206206
// SWIFT_ENABLE_TENSORFLOW
207-
/// Lookup table for SIL differentiability witnesses from original functions.
208-
/// Indexed by key type: original function, parameter indices, result indices,
209-
/// and derivative generic signature.
210-
llvm::DenseMap<SILDifferentiabilityWitnessKey, SILDifferentiabilityWitness *>
211-
DifferentiabilityWitnessMap;
207+
/// Lookup table for SIL differentiability witnesses, keyed by mangled name.
208+
llvm::StringMap<SILDifferentiabilityWitness *> DifferentiabilityWitnessMap;
212209

213210
/// The list of SILDifferentiabilityWitnesses in the module.
214211
DifferentiabilityWitnessListType differentiabilityWitnesses;
@@ -609,6 +606,15 @@ class SILModule {
609606
/// hierarchy of \p Class.
610607
SILFunction *lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member);
611608

609+
// SWIFT_ENABLE_TENSORFLOW
610+
/// Look up the differentiability witness with the given name.
611+
SILDifferentiabilityWitness *lookUpDifferentiabilityWitness(StringRef name);
612+
613+
/// Look up the differentiability witness corresponding to the given key.
614+
SILDifferentiabilityWitness *
615+
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);
616+
// SWIFT_ENABLE_TENSORFLOW_END
617+
612618
// Given a protocol, attempt to create a default witness table declaration
613619
// for it.
614620
SILDefaultWitnessTable *

lib/Parse/ParseStmt.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ ParserStatus Parser::parseBraceItems(SmallVectorImpl<ASTNode> &Entries,
326326
Tok.isNot(tok::kw_sil_global) &&
327327
Tok.isNot(tok::kw_sil_witness_table) &&
328328
Tok.isNot(tok::kw_sil_default_witness_table) &&
329+
// SWIFT_ENABLE_TENSORFLOW
330+
Tok.isNot(tok::kw_sil_differentiability_witness) &&
331+
// SWIFT_ENABLE_TENSORFLOW_END
329332
Tok.isNot(tok::kw_sil_property) &&
330333
(isConditionalBlock ||
331334
!isTerminatorForBraceItemListKind(Kind, Entries))) {

0 commit comments

Comments
 (0)