Skip to content

Commit b25dd43

Browse files
author
Marc Rasi
committed
IRGen differentiability witness tables
1 parent 4d1c1df commit b25dd43

File tree

12 files changed

+304
-43
lines changed

12 files changed

+304
-43
lines changed

include/swift/AST/PrettyStackTrace.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,28 @@ class PrettyStackTraceSelector : public llvm::PrettyStackTraceEntry {
187187
void print(llvm::raw_ostream &OS) const override;
188188
};
189189

190+
// SWIFT_ENABLE_TENSORFLOW
191+
/// PrettyStackTraceDifferentiabilityWitness - Observe that we are processing a
192+
/// specific differentiability witness.
193+
class PrettyStackTraceDifferentiabilityWitness
194+
: public llvm::PrettyStackTraceEntry {
195+
ASTContext &Context;
196+
const SILDifferentiabilityWitnessKey Key;
197+
const char *Action;
198+
199+
public:
200+
PrettyStackTraceDifferentiabilityWitness(
201+
ASTContext &C, const char *action,
202+
const SILDifferentiabilityWitnessKey key)
203+
: Context(C), Key(key), Action(action) {}
204+
virtual void print(llvm::raw_ostream &OS) const;
205+
};
206+
207+
void printDifferentiabilityWitnessDescription(
208+
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
209+
ASTContext &Context, bool addNewline = true);
210+
// SWIFT_ENABLE_TENSORFLOW END
211+
190212
} // end namespace swift
191213

192214
#endif

include/swift/IRGen/Linking.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,11 @@ class LinkEntity {
346346
/// ProtocolConformance*.
347347
ProtocolWitnessTableLazyCacheVariable,
348348

349+
// SWIFT_ENABLE_TENSORFLOW
350+
/// A SIL differentiability witness.
351+
DifferentiabilityWitness,
352+
// SWIFT_ENABLE_TENSORFLOW_END
353+
349354
// Everything following this is a type kind.
350355

351356
/// A value witness for a type.
@@ -468,6 +473,15 @@ class LinkEntity {
468473
associatedProtocol));
469474
}
470475

476+
// SWIFT_ENABLE_TENSORFLOW
477+
void
478+
setForDifferentiabilityWitness(Kind kind,
479+
const SILDifferentiabilityWitness *witness) {
480+
Pointer = const_cast<void *>(static_cast<const void *>(witness));
481+
Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind));
482+
}
483+
// SWIFT_ENABLE_TENSORFLOW_END
484+
471485
// We store associated types using their index in their parent protocol
472486
// in order to avoid bloating LinkEntity out to three key pointers.
473487
static unsigned getAssociatedTypeIndex(const ProtocolConformance *conformance,
@@ -848,6 +862,16 @@ class LinkEntity {
848862
return entity;
849863
}
850864

865+
// SWIFT_ENABLE_TENSORFLOW
866+
static LinkEntity
867+
forDifferentiabilityWitness(const SILDifferentiabilityWitness *witness) {
868+
LinkEntity entity;
869+
entity.setForDifferentiabilityWitness(Kind::DifferentiabilityWitness,
870+
witness);
871+
return entity;
872+
}
873+
// SWIFT_ENABLE_TENSORFLOW_END
874+
851875
static LinkEntity
852876
forGenericProtocolWitnessTableInstantiationFunction(
853877
const ProtocolConformance *C) {
@@ -1043,6 +1067,11 @@ class LinkEntity {
10431067
return reinterpret_cast<SILGlobalVariable*>(Pointer);
10441068
}
10451069

1070+
SILDifferentiabilityWitness *getSILDifferentiabilityWitness() const {
1071+
assert(getKind() == Kind::DifferentiabilityWitness);
1072+
return reinterpret_cast<SILDifferentiabilityWitness *>(Pointer);
1073+
}
1074+
10461075
const RootProtocolConformance *getRootProtocolConformance() const {
10471076
assert(isRootProtocolConformanceKind(getKind()));
10481077
return cast<RootProtocolConformance>(getProtocolConformance());

lib/AST/PrettyStackTrace.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,18 @@ void PrettyStackTraceGenericSignature::print(llvm::raw_ostream &out) const {
263263
void PrettyStackTraceSelector::print(llvm::raw_ostream &out) const {
264264
out << "While " << Action << " '" << Selector << "'";
265265
}
266+
267+
void PrettyStackTraceDifferentiabilityWitness::print(
268+
llvm::raw_ostream &out) const {
269+
out << "While " << Action << ' ';
270+
printDifferentiabilityWitnessDescription(out, Key, Context);
271+
}
272+
273+
void swift::printDifferentiabilityWitnessDescription(
274+
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
275+
ASTContext &Context, bool addNewline) {
276+
out << key.first << " ";
277+
key.second.print(out);
278+
if (addNewline)
279+
out << '\n';
280+
}

lib/IRGen/GenDecl.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,21 @@ void IRGenerator::emitGlobalTopLevel() {
10641064
CurrentIGMPtr IGM = getGenModule(prop.getDecl()->getInnermostDeclContext());
10651065
IGM->emitSILProperty(&prop);
10661066
}
1067-
1067+
1068+
// SWIFT_ENABLE_TENSORFLOW
1069+
// Emit differentiability witnesses.
1070+
for (auto &dw :
1071+
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
1072+
if (dw.isDeclaration())
1073+
continue;
1074+
1075+
// Emit into same IRGenModule as the VJP.
1076+
CurrentIGMPtr IGM = getGenModule(dw.getVJP());
1077+
1078+
IGM->emitSILDifferentiabilityWitness(&dw);
1079+
}
1080+
// SWIFT_ENABLE_TENSORFLOW_END
1081+
10681082
// Emit code coverage mapping data.
10691083
PrimaryIGM->emitCoverageMapping();
10701084

@@ -4392,6 +4406,15 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
43924406
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
43934407
}
43944408

4409+
// SWIFT_ENABLE_TENSORFLOW
4410+
/// Look up the address of a witness table.
4411+
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
4412+
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
4413+
auto entity = LinkEntity::forDifferentiabilityWitness(witness);
4414+
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
4415+
}
4416+
// SWIFT_ENABLE_TENSORFLOW
4417+
43954418
llvm::Function *
43964419
IRGenModule::getAddrOfAssociatedTypeWitnessTableAccessFunction(
43974420
const NormalProtocolConformance *conformance,

lib/IRGen/GenProto.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,40 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) {
21662166
RequireMetadata);
21672167
}
21682168

2169+
// SWIFT_ENABLE_TENSORFLOW
2170+
void IRGenModule::emitSILDifferentiabilityWitness(
2171+
SILDifferentiabilityWitness *dw) {
2172+
PrettyStackTraceDifferentiabilityWitness _st(
2173+
Context, "emitting differentiability witness for", dw->getKey());
2174+
2175+
// Don't emit declarations.
2176+
if (dw->isDeclaration())
2177+
return;
2178+
2179+
ConstantInitBuilder builder(*this);
2180+
auto diffWitnessContents = builder.beginStruct();
2181+
2182+
// TODO(marcrasi): When the differentiation pass generates JVP/VJP for
2183+
// witnesses, remove the nullptr case and add assertions that the JVP/VJP
2184+
// exist.
2185+
if (dw->getJVP()) {
2186+
diffWitnessContents.addBitCast(
2187+
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
2188+
} else {
2189+
diffWitnessContents.addNullPointer(Int8PtrTy);
2190+
}
2191+
if (dw->getVJP()) {
2192+
diffWitnessContents.addBitCast(
2193+
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
2194+
} else {
2195+
diffWitnessContents.addNullPointer(Int8PtrTy);
2196+
}
2197+
2198+
getAddrOfDifferentiabilityWitness(
2199+
dw, diffWitnessContents.finishAndCreateFuture());
2200+
}
2201+
// SWIFT_ENABLE_TENSORFLOW_END
2202+
21692203
/// True if a function's signature in LLVM carries polymorphic parameters.
21702204
/// Generic functions and protocol witnesses carry polymorphic parameters.
21712205
bool irgen::hasPolymorphicParameters(CanSILFunctionType ty) {

lib/IRGen/IRGenModule.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,11 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
524524

525525
DynamicReplacementKeyTy = createStructType(*this, "swift.dyn_repl_key",
526526
{RelativeAddressTy, Int32Ty});
527+
528+
// SWIFT_ENABLE_TENSORFLOW
529+
DifferentiabilityWitnessTy = createStructType(
530+
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
531+
// SWIFT_ENABLE_TENSORFLOW_END
527532
}
528533

529534
IRGenModule::~IRGenModule() {

lib/IRGen/IRGenModule.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,10 @@ class IRGenModule {
639639
*DynamicReplacementLinkEntryPtrTy; // %link_entry*
640640
llvm::StructType *DynamicReplacementKeyTy; // { i32, i32}
641641

642+
// SWIFT_ENABLE_TENSORFLOW
643+
llvm::StructType *DifferentiabilityWitnessTy; // { i8*, i8* }
644+
// SWIFT_ENABLE_TENSORFLOW_END
645+
642646
llvm::GlobalVariable *TheTrivialPropertyDescriptor = nullptr;
643647

644648
/// Used to create unique names for class layout types with tail allocated
@@ -1233,6 +1237,9 @@ private: \
12331237
void emitSILFunction(SILFunction *f);
12341238
void emitSILWitnessTable(SILWitnessTable *wt);
12351239
void emitSILProperty(SILProperty *prop);
1240+
// SWIFT_ENABLE_TENSORFLOW
1241+
void emitSILDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
1242+
// SWIFT_ENABLE_TENSORFLOW END
12361243
void emitSILStaticInitializers();
12371244
llvm::Constant *emitFixedTypeLayout(CanType t, const FixedTypeInfo &ti);
12381245
void emitProtocolConformance(const ConformanceDescription &record);
@@ -1411,6 +1418,12 @@ private: \
14111418
const NormalProtocolConformance *C,
14121419
ConstantInit definition = ConstantInit());
14131420

1421+
// SWIFT_ENABLE_TENSORFLOW
1422+
llvm::Constant *
1423+
getAddrOfDifferentiabilityWitness(const SILDifferentiabilityWitness *witness,
1424+
ConstantInit definition = ConstantInit());
1425+
// SWIFT_ENABLE_TENSORFLOW_END
1426+
14141427
llvm::Function *
14151428
getAddrOfGenericWitnessTableInstantiationFunction(
14161429
const NormalProtocolConformance *C);

lib/IRGen/Linking.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ std::string LinkEntity::mangleAsString() const {
414414
case Kind::ReflectionAssociatedTypeDescriptor:
415415
return mangler.mangleReflectionAssociatedTypeDescriptor(
416416
getProtocolConformance());
417+
// SWIFT_ENABLE_TENSORFLOW
418+
case Kind::DifferentiabilityWitness:
419+
return mangler.mangleSILDifferentiabilityWitnessKey(
420+
{getSILDifferentiabilityWitness()->getOriginalFunction()->getName(),
421+
getSILDifferentiabilityWitness()->getConfig()});
422+
// SWIFT_ENABLE_TENSORFLOW_END
417423
}
418424
llvm_unreachable("bad entity kind!");
419425
}
@@ -659,6 +665,10 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
659665
case Kind::ExtensionDescriptor:
660666
case Kind::AnonymousDescriptor:
661667
return SILLinkage::Shared;
668+
// SWIFT_ENABLE_TENSORFLOW
669+
case Kind::DifferentiabilityWitness:
670+
return getSILDifferentiabilityWitness()->getLinkage();
671+
// SWIFT_ENABLE_TENSORFLOW_END
662672
}
663673
llvm_unreachable("bad link entity kind");
664674
}
@@ -803,6 +813,10 @@ bool LinkEntity::isAvailableExternally(IRGenModule &IGM) const {
803813
case Kind::DynamicallyReplaceableFunctionImpl:
804814
case Kind::DynamicallyReplaceableFunctionKeyAST:
805815
llvm_unreachable("Relative reference to unsupported link entity");
816+
// SWIFT_ENABLE_TENSORFLOW
817+
case Kind::DifferentiabilityWitness:
818+
return true;
819+
// SWIFT_ENABLE_TENSORFLOW_END
806820
}
807821
llvm_unreachable("bad link entity kind");
808822
}
@@ -904,6 +918,10 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
904918
return IGM.ObjCResilientClassStubTy;
905919
}
906920
llvm_unreachable("invalid metadata address");
921+
// SWIFT_ENABLE_TENSORFLOW
922+
case Kind::DifferentiabilityWitness:
923+
return IGM.DifferentiabilityWitnessTy;
924+
// SWIFT_ENABLE_TENSORFLOW_END
907925
default:
908926
llvm_unreachable("declaration LLVM type not specified");
909927
}
@@ -956,6 +974,10 @@ Alignment LinkEntity::getAlignment(IRGenModule &IGM) const {
956974
return Alignment(8);
957975
case Kind::SILFunction:
958976
return Alignment(1);
977+
// SWIFT_ENABLE_TENSORFLOW
978+
case Kind::DifferentiabilityWitness:
979+
return IGM.getPointerAlignment();
980+
// SWIFT_ENABLE_TENSORFLOW_END
959981
default:
960982
llvm_unreachable("alignment not specified");
961983
}
@@ -1053,6 +1075,11 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
10531075
case Kind::ReflectionFieldDescriptor:
10541076
case Kind::CoroutineContinuationPrototype:
10551077
return false;
1078+
1079+
// SWIFT_ENABLE_TENSORFLOW
1080+
case Kind::DifferentiabilityWitness:
1081+
return false;
1082+
// SWIFT_ENABLE_TENSORFLOW_END
10561083
}
10571084

10581085
llvm_unreachable("Bad link entity kind");
@@ -1182,6 +1209,11 @@ const SourceFile *LinkEntity::getSourceFileForEmission() const {
11821209
case Kind::ValueWitness:
11831210
case Kind::ValueWitnessTable:
11841211
return nullptr;
1212+
1213+
// SWIFT_ENABLE_TENSORFLOW
1214+
case Kind::DifferentiabilityWitness:
1215+
return nullptr;
1216+
// SWIFT_ENABLE_TENSORFLOW_END
11851217
}
11861218

11871219
return sf;

lib/SILOptimizer/IPO/DeadFunctionElimination.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,14 @@ class DeadFunctionElimination : FunctionLivenessComputation {
606606
}
607607
}
608608

609+
// SWIFT_ENABLE_TENSORFLOW
610+
// Check differentiable function witness entries.
611+
for (auto &dw : Module->getDifferentiabilityWitnessList()) {
612+
if (dw.getJVP())
613+
ensureAlive(dw.getJVP());
614+
if (dw.getVJP())
615+
ensureAlive(dw.getVJP());
616+
}
609617
}
610618

611619
/// Removes all dead methods from vtables and witness tables.

0 commit comments

Comments
 (0)