Skip to content

Commit 76729c4

Browse files
authored
[AutoDiff] SILGen differentiability witnesses. (#27652)
Generate SIL differentiability witnesses from AST `@differentiable` attributes, using lowered parameter indices, result indices (currently with capacity 1 and set index 0), and derivative generic signature. Resolves TF-869. The TF-866 master issue tracks all retroactive derivative registration tasks.
1 parent 0d17ddf commit 76729c4

File tree

7 files changed

+205
-96
lines changed

7 files changed

+205
-96
lines changed

lib/SIL/SILPrinter.cpp

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3085,6 +3085,51 @@ void SILDefaultWitnessTable::dump() const {
30853085
print(llvm::errs());
30863086
}
30873087

3088+
// TODO(TF-893): Use this helper to dedupe the same logic in
3089+
// `SILFunction::print`.
3090+
static void printSILFunctionNameAndType(
3091+
llvm::raw_ostream &OS, SILFunction *function) {
3092+
function->printName(OS);
3093+
OS << " : $";
3094+
llvm::DenseMap<CanType, Identifier> Aliases;
3095+
llvm::DenseSet<Identifier> UsedNames;
3096+
auto sig = function->getLoweredFunctionType()->getGenericSignature();
3097+
auto *env = function->getGenericEnvironment();
3098+
if (sig && env) {
3099+
llvm::SmallString<16> disambiguatedNameBuf;
3100+
unsigned disambiguatedNameCounter = 1;
3101+
for (auto *paramTy : sig->getGenericParams()) {
3102+
auto sugaredTy = env->getSugaredType(paramTy);
3103+
Identifier name = sugaredTy->getName();
3104+
while (!UsedNames.insert(name).second) {
3105+
disambiguatedNameBuf.clear();
3106+
{
3107+
llvm::raw_svector_ostream names(disambiguatedNameBuf);
3108+
names << sugaredTy->getName() << disambiguatedNameCounter++;
3109+
}
3110+
name = function->getASTContext().getIdentifier(disambiguatedNameBuf);
3111+
}
3112+
if (name != sugaredTy->getName()) {
3113+
Aliases[paramTy->getCanonicalType()] = name;
3114+
3115+
// Also for the archetype
3116+
auto archetypeTy = env->mapTypeIntoContext(paramTy)
3117+
->getAs<ArchetypeType>();
3118+
if (archetypeTy)
3119+
Aliases[archetypeTy->getCanonicalType()] = name;
3120+
}
3121+
}
3122+
}
3123+
3124+
{
3125+
PrintOptions withGenericEnvironment = PrintOptions::printSIL();
3126+
withGenericEnvironment.GenericEnv = env;
3127+
withGenericEnvironment.AlternativeTypeNames =
3128+
Aliases.empty() ? nullptr : &Aliases;
3129+
function->getLoweredFunctionType()->print(OS, withGenericEnvironment);
3130+
}
3131+
}
3132+
30883133
// SWIFT_ENABLE_TENSORFLOW
30893134
void SILDifferentiabilityWitness::print(
30903135
llvm::raw_ostream &OS, bool verbose) const {
@@ -3107,7 +3152,7 @@ void SILDifferentiabilityWitness::print(
31073152
interleave(getResultIndices()->getIndices(),
31083153
[&](unsigned index) { OS << index; },
31093154
[&] { OS << ' '; });
3110-
OS << ']';
3155+
OS << "] ";
31113156
// ([where ...])?
31123157
if (auto *derivativeGenSig = getDerivativeGenericSignature()) {
31133158
ArrayRef<Requirement> requirements;
@@ -3123,28 +3168,34 @@ void SILDifferentiabilityWitness::print(
31233168
}
31243169
}
31253170
if (!requirements.empty()) {
3126-
OS << " [where ";
3171+
OS << "[where ";
31273172
auto subPrinter = PrintOptions::printSIL();
3173+
subPrinter.GenericEnv = origGenEnv;
31283174
interleave(requirements,
31293175
[&](Requirement req) {
31303176
req.print(OS, subPrinter);
31313177
},
31323178
[&] { OS << ", "; });
3133-
OS << ']';
3179+
OS << "] ";
31343180
}
31353181
}
31363182
// @original-function-name : $original-sil-type
3137-
OS << " @" << originalFunction->getName() << " : "
3138-
<< originalFunction->getLoweredType();
3183+
printSILFunctionNameAndType(OS, originalFunction);
31393184
// {
31403185
// jvp: @jvp-function-name : $jvp-sil-type
31413186
// vjp: @vjp-function-name : $vjp-sil-type
31423187
// }
31433188
OS << " {\n";
3144-
if (jvp)
3145-
OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n';
3146-
if (vjp)
3147-
OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n';
3189+
if (jvp) {
3190+
OS << " jvp: ";
3191+
printSILFunctionNameAndType(OS, jvp);
3192+
OS << '\n';
3193+
}
3194+
if (vjp) {
3195+
OS << " vjp: ";
3196+
printSILFunctionNameAndType(OS, vjp);
3197+
OS << '\n';
3198+
}
31483199
OS << "}\n\n";
31493200
}
31503201

lib/SILGen/SILGen.cpp

Lines changed: 123 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -752,87 +752,135 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
752752
F->print(llvm::dbgs()));
753753

754754
// SWIFT_ENABLE_TENSORFLOW
755-
// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
756-
if (constant.hasDecl() && constant.getAbstractFunctionDecl()) {
755+
// Visit `@differentiable` attributes and generate SIL differentiability
756+
// witnesses.
757+
// TODO(TF-835): Visit `@differentiating` attributes when type-checking no
758+
// longer generates implicit `@differentiable` attributes. See TF-835 for
759+
// replacement code.
760+
// Skip if the SILDeclRef is a:
761+
// - Default argument generator function.
762+
// - Thunk.
763+
if (constant.hasDecl() && constant.getAbstractFunctionDecl() &&
764+
constant.kind != SILDeclRef::Kind::DefaultArgGenerator &&
765+
!constant.isThunk()) {
757766
auto *AFD = constant.getAbstractFunctionDecl();
758-
auto origFnType = AFD->getInterfaceType()->castTo<AnyFunctionType>();
759-
auto origSilFnType = F->getLoweredFunctionType();
760-
// Jointly iterate over AST `@differentiable` attributes and SIL
761-
// `[differentiable]` attributes.
762-
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
763-
auto silDiffAttrs = F->getDifferentiableAttrs();
764-
for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) {
765-
auto *diffAttr = const_cast<DifferentiableAttr *>(std::get<0>(pair));
766-
auto *silDiffAttr = std::get<1>(pair);
767-
// Compute lowered parameter indices.
768-
auto *paramIndices = diffAttr->getParameterIndices();
769-
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
770-
paramIndices, origFnType);
771-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
772-
assert(silDiffAttr->getIndices() == indices &&
773-
"Expected matching @differentiable and [differentiable] indices");
774-
775-
auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule());
776-
auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType(
777-
indices.parameters, indices.source,
778-
AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance);
779-
auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType(
780-
indices.parameters, indices.source,
781-
AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance);
782-
783-
// Self reordering is necessary if wrt at least two parameters, including
784-
// self.
785-
auto shouldReorderSelf = [&]() {
786-
if (!F->hasSelfParam())
787-
return false;
788-
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
789-
if (!indices.isWrtParameter(selfParamIndex))
790-
return false;
791-
return indices.parameters->getNumIndices() > 1;
792-
};
793-
bool reorderSelf = shouldReorderSelf();
794-
795-
// Thunk JVP method, if it is defined.
796-
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
797-
SILFunction *jvpThunk;
798-
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
799-
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
800-
jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
801-
F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP,
802-
reorderSelf);
803-
} else {
804-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
805-
AutoDiffDerivativeFunctionKind::JVP,
806-
diffAttr->getParameterIndices(), AFD->getASTContext());
807-
jvpThunk = getOrCreateAutoDiffThunk(
808-
constant.asAutoDiffDerivativeFunction(id), jvpFn,
809-
expectedJVPType);
810-
}
811-
silDiffAttr->setJVPName(jvpThunk->getName());
812-
}
813-
// Thunk VJP method, if it is defined.
814-
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
815-
SILFunction *vjpThunk;
816-
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
817-
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
818-
vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
819-
F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP,
820-
reorderSelf);
821-
} else {
822-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
823-
AutoDiffDerivativeFunctionKind::VJP,
824-
diffAttr->getParameterIndices(), AFD->getASTContext());
825-
vjpThunk = getOrCreateAutoDiffThunk(
826-
constant.asAutoDiffDerivativeFunction(id), vjpFn,
827-
expectedVJPType);
828-
}
829-
silDiffAttr->setVJPName(vjpThunk->getName());
830-
}
767+
// Visit all `@differentiable` attributes.
768+
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
769+
SILFunction *jvp = nullptr;
770+
SILFunction *vjp = nullptr;
771+
if (auto *jvpDecl = diffAttr->getJVPFunction())
772+
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
773+
if (auto *vjpDecl = diffAttr->getVJPFunction())
774+
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
775+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
776+
AutoDiffConfig config{diffAttr->getParameterIndices(), resultIndices,
777+
diffAttr->getDerivativeGenericSignature()};
778+
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
831779
}
832780
}
833781
F->verify();
834782
}
835783

784+
void SILGenModule::emitDifferentiabilityWitness(
785+
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
786+
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) {
787+
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
788+
auto origSilFnType = originalFunction->getLoweredFunctionType();
789+
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
790+
config.parameterIndices, origFnType);
791+
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
792+
// parameters corresponding to captured variables. These parameters do not
793+
// appear in the type of `origFnType`.
794+
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
795+
// take `CaptureInfo` into account.
796+
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
797+
loweredParamIndices = loweredParamIndices->extendingCapacity(
798+
getASTContext(), origSilFnType->getNumParameters());
799+
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
800+
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
801+
802+
// Self reordering thunk is necessary if wrt at least two parameters,
803+
// including self.
804+
auto shouldReorderSelf = [&]() {
805+
if (!originalFunction->hasSelfParam())
806+
return false;
807+
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
808+
if (!indices.isWrtParameter(selfParamIndex))
809+
return false;
810+
return indices.parameters->getNumIndices() > 1;
811+
};
812+
bool reorderSelf = shouldReorderSelf();
813+
814+
CanGenericSignature derivativeCanGenSig;
815+
if (auto *derivativeGenSig = config.derivativeGenericSignature)
816+
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
817+
// TODO(TF-835): Use simpler derivative generic signature logic below when
818+
// type-checking no longer generates implicit `@differentiable` attributes.
819+
// See TF-835 for replacement code.
820+
if (jvp) {
821+
auto jvpCanGenSig = jvp->getLoweredFunctionType()->getGenericSignature();
822+
if (!derivativeCanGenSig && jvpCanGenSig)
823+
derivativeCanGenSig = jvpCanGenSig;
824+
assert(derivativeCanGenSig == jvpCanGenSig);
825+
}
826+
if (vjp) {
827+
auto vjpCanGenSig = vjp->getLoweredFunctionType()->getGenericSignature();
828+
if (!derivativeCanGenSig && vjpCanGenSig)
829+
derivativeCanGenSig = vjpCanGenSig;
830+
assert(derivativeCanGenSig == vjpCanGenSig);
831+
}
832+
// Create new SIL differentiability witness.
833+
// Witness JVP and VJP are set below.
834+
// TODO(TF-919): Explore creating serialized differentiability witnesses.
835+
// Currently, differentiability witnesses are never serialized to avoid
836+
// deserialization issues where JVP/VJP functions cannot be found.
837+
auto *diffWitness = SILDifferentiabilityWitness::create(
838+
M, originalFunction->getLinkage(), originalFunction,
839+
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
840+
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
841+
842+
// Set derivative function in differentiability witness.
843+
auto setDerivativeInDifferentiabilityWitness =
844+
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
845+
auto expectedDerivativeType =
846+
origSilFnType->getAutoDiffDerivativeFunctionType(
847+
indices.parameters, indices.source, kind, Types,
848+
LookUpConformanceInModule(M.getSwiftModule()));
849+
// Thunk derivative function.
850+
SILFunction *derivativeThunk;
851+
if (reorderSelf ||
852+
derivative->getLoweredFunctionType() != expectedDerivativeType) {
853+
derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
854+
originalFunction, indices, derivative, kind, reorderSelf);
855+
} else {
856+
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
857+
// the AST-level parameter indices, not the SIL-level ones.
858+
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
859+
kind, config.parameterIndices, getASTContext());
860+
derivativeThunk = getOrCreateAutoDiffThunk(
861+
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
862+
expectedDerivativeType);
863+
}
864+
// Check whether the witness already has this same derivative.
865+
// TODO(TF-835): Remove condition below and simplify assertion to
866+
// `!diffWitness->getDerivative(kind)` after `@differentiating` attribute
867+
// type-checking no longer generates implicit `@differentiable` attributes.
868+
auto *existingDerivative = diffWitness->getDerivative(kind);
869+
if (existingDerivative && existingDerivative == derivativeThunk)
870+
return;
871+
assert(!existingDerivative &&
872+
"SIL differentiability witness already has a different existing "
873+
"derivative");
874+
diffWitness->setDerivative(kind, derivativeThunk);
875+
};
876+
if (jvp)
877+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP,
878+
jvp);
879+
if (vjp)
880+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP,
881+
vjp);
882+
}
883+
836884
void SILGenModule::
837885
emitMarkFunctionEscapeForTopLevelCodeGlobals(SILLocation loc,
838886
const CaptureInfo &captureInfo) {

lib/SILGen/SILGen.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
318318
/// Emit the self-conformance witness table for a protocol.
319319
void emitSelfConformanceWitnessTable(ProtocolDecl *protocol);
320320

321+
// SWIFT_ENABLE_TENSORFLOW
322+
/// Emit the differentiability witness for the given original function
323+
/// declaration and SIL function, autodiff configuration, and JVP and VJP
324+
/// functions (null if undefined).
325+
void emitDifferentiabilityWitness(AbstractFunctionDecl *originalAFD,
326+
SILFunction *originalFunction,
327+
const AutoDiffConfig &config,
328+
SILFunction *jvp, SILFunction *vjp);
329+
// SWIFT_ENABLE_TENSORFLOW END
330+
321331
/// Emit the lazy initializer function for a global pattern binding
322332
/// declaration.
323333
SILFunction *emitLazyGlobalInitializer(StringRef funcName,

lib/Serialization/DeserializeSIL.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,6 +3016,9 @@ void SILDeserializer::readWitnessTableEntries(
30163016
// Another record means the end of this WitnessTable.
30173017
while (kind != SIL_WITNESS_TABLE &&
30183018
kind != SIL_DEFAULT_WITNESS_TABLE &&
3019+
// SWIFT_ENABLE_TENSORFLOW
3020+
kind != SIL_DIFFERENTIABILITY_WITNESS &&
3021+
// SWIFT_ENABLE_TENSORFLOW END
30193022
kind != SIL_FUNCTION) {
30203023
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
30213024
witnessEntries.push_back(SILDefaultWitnessTable::Entry());

lib/Serialization/Serialization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ void Serializer::writeBlockInfoBlock() {
789789
BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION);
790790
BLOCK_RECORD(sil_block, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT);
791791
BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION_EXTRACT);
792+
BLOCK_RECORD(sil_block, SIL_DIFFERENTIABILITY_WITNESS);
792793
// SWIFT_ENABLE_TENSORFLOW END
793794

794795
// These layouts can exist in both decl blocks and sil blocks.
@@ -829,6 +830,7 @@ void Serializer::writeBlockInfoBlock() {
829830
BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_OFFSETS);
830831
BLOCK_RECORD(sil_index_block, SIL_PROPERTY_OFFSETS);
831832
// SWIFT_ENABLE_TENSORFLOW
833+
BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_NAMES);
832834
BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS);
833835
// SWIFT_ENABLE_TENSORFLOW END
834836

lib/Serialization/SerializeSIL.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2322,7 +2322,7 @@ void SILSerializer::writeIndexTables() {
23222322
}
23232323

23242324
// SWIFT_ENABLE_TENSORFLOW
2325-
if (!DifferentiabilityWitnessOffset.empty()) {
2325+
if (!DifferentiabilityWitnessList.empty()) {
23262326
writeIndexTable(S, List,
23272327
sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES,
23282328
DifferentiabilityWitnessList);
@@ -2542,17 +2542,12 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) {
25422542
DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo());
25432543

25442544
auto *original = dw.getOriginalFunction();
2545-
addReferencedSILFunction(original, /*DeclOnly*/ true);
25462545
IdentifierID jvpID = 0;
25472546
IdentifierID vjpID = 0;
2548-
if (auto *jvp = dw.getJVP()) {
2549-
addReferencedSILFunction(jvp, /*DeclOnly*/ true);
2550-
jvpID = S.addUniquedStringRef(jvp->getName());
2551-
}
2552-
if (auto *vjp = dw.getVJP()) {
2553-
addReferencedSILFunction(vjp, /*DeclOnly*/ true);
2554-
vjpID = S.addUniquedStringRef(vjp->getName());
2555-
}
2547+
if (auto *jvp = dw.getJVP())
2548+
jvpID = addSILFunctionRef(jvp);
2549+
if (auto *vjp = dw.getVJP())
2550+
vjpID = addSILFunctionRef(vjp);
25562551
SmallVector<unsigned, 8> parameterAndResultIndices(
25572552
dw.getParameterIndices()->begin(), dw.getParameterIndices()->end());
25582553
parameterAndResultIndices.append(dw.getResultIndices()->begin(),
@@ -2569,7 +2564,7 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) {
25692564

25702565
DifferentiabilityWitnessLayout::emitRecord(
25712566
Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code],
2572-
S.addUniquedStringRef(original->getName()),
2567+
addSILFunctionRef(original),
25732568
toStableSILLinkage(dw.getLinkage()),
25742569
dw.isSerialized(),
25752570
S.addGenericSignatureRef(dw.getDerivativeGenericSignature()),

0 commit comments

Comments
 (0)