Skip to content

[AutoDiff] IRGen differentiability witness tables #28067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/swift/AST/PrettyStackTrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,26 @@ class PrettyStackTraceSelector : public llvm::PrettyStackTraceEntry {
void print(llvm::raw_ostream &OS) const override;
};

// SWIFT_ENABLE_TENSORFLOW
/// PrettyStackTraceDifferentiabilityWitness - Observe that we are processing a
/// specific differentiability witness.
class PrettyStackTraceDifferentiabilityWitness
: public llvm::PrettyStackTraceEntry {
const SILDifferentiabilityWitnessKey Key;
const char *Action;

public:
PrettyStackTraceDifferentiabilityWitness(
const char *action, const SILDifferentiabilityWitnessKey key)
: Key(key), Action(action) {}
virtual void print(llvm::raw_ostream &OS) const;
};

void printDifferentiabilityWitnessDescription(
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
bool addNewline = true);
// SWIFT_ENABLE_TENSORFLOW END

} // end namespace swift

#endif
29 changes: 29 additions & 0 deletions include/swift/IRGen/Linking.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ class LinkEntity {
/// ProtocolConformance*.
ProtocolWitnessTableLazyCacheVariable,

// SWIFT_ENABLE_TENSORFLOW
/// A SIL differentiability witness.
DifferentiabilityWitness,
// SWIFT_ENABLE_TENSORFLOW_END

// Everything following this is a type kind.

/// A value witness for a type.
Expand Down Expand Up @@ -468,6 +473,15 @@ class LinkEntity {
associatedProtocol));
}

// SWIFT_ENABLE_TENSORFLOW
void
setForDifferentiabilityWitness(Kind kind,
const SILDifferentiabilityWitness *witness) {
Pointer = const_cast<void *>(static_cast<const void *>(witness));
Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind));
}
// SWIFT_ENABLE_TENSORFLOW_END

// We store associated types using their index in their parent protocol
// in order to avoid bloating LinkEntity out to three key pointers.
static unsigned getAssociatedTypeIndex(const ProtocolConformance *conformance,
Expand Down Expand Up @@ -848,6 +862,16 @@ class LinkEntity {
return entity;
}

// SWIFT_ENABLE_TENSORFLOW
static LinkEntity
forDifferentiabilityWitness(const SILDifferentiabilityWitness *witness) {
LinkEntity entity;
entity.setForDifferentiabilityWitness(Kind::DifferentiabilityWitness,
witness);
return entity;
}
// SWIFT_ENABLE_TENSORFLOW_END

static LinkEntity
forGenericProtocolWitnessTableInstantiationFunction(
const ProtocolConformance *C) {
Expand Down Expand Up @@ -1043,6 +1067,11 @@ class LinkEntity {
return reinterpret_cast<SILGlobalVariable*>(Pointer);
}

SILDifferentiabilityWitness *getSILDifferentiabilityWitness() const {
assert(getKind() == Kind::DifferentiabilityWitness);
return reinterpret_cast<SILDifferentiabilityWitness *>(Pointer);
}

const RootProtocolConformance *getRootProtocolConformance() const {
assert(isRootProtocolConformanceKind(getKind()));
return cast<RootProtocolConformance>(getProtocolConformance());
Expand Down
15 changes: 15 additions & 0 deletions lib/AST/PrettyStackTrace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,18 @@ void PrettyStackTraceGenericSignature::print(llvm::raw_ostream &out) const {
void PrettyStackTraceSelector::print(llvm::raw_ostream &out) const {
out << "While " << Action << " '" << Selector << "'";
}

void PrettyStackTraceDifferentiabilityWitness::print(
llvm::raw_ostream &out) const {
out << "While " << Action << ' ';
printDifferentiabilityWitnessDescription(out, Key);
}

void swift::printDifferentiabilityWitnessDescription(
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
bool addNewline) {
out << key.first << " ";
key.second.print(out);
if (addNewline)
out << '\n';
}
25 changes: 24 additions & 1 deletion lib/IRGen/GenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,21 @@ void IRGenerator::emitGlobalTopLevel() {
CurrentIGMPtr IGM = getGenModule(prop.getDecl()->getInnermostDeclContext());
IGM->emitSILProperty(&prop);
}


// SWIFT_ENABLE_TENSORFLOW
// Emit differentiability witnesses.
for (auto &dw :
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
if (dw.isDeclaration())
continue;

// Emit into same IRGenModule as the VJP.
CurrentIGMPtr IGM = getGenModule(dw.getVJP());

IGM->emitSILDifferentiabilityWitness(&dw);
}
// SWIFT_ENABLE_TENSORFLOW_END

// Emit code coverage mapping data.
PrimaryIGM->emitCoverageMapping();

Expand Down Expand Up @@ -4392,6 +4406,15 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
}

// SWIFT_ENABLE_TENSORFLOW
/// Look up the address of a witness table.
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
auto entity = LinkEntity::forDifferentiabilityWitness(witness);
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
}
// SWIFT_ENABLE_TENSORFLOW

llvm::Function *
IRGenModule::getAddrOfAssociatedTypeWitnessTableAccessFunction(
const NormalProtocolConformance *conformance,
Expand Down
34 changes: 34 additions & 0 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2166,6 +2166,40 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) {
RequireMetadata);
}

// SWIFT_ENABLE_TENSORFLOW
void IRGenModule::emitSILDifferentiabilityWitness(
SILDifferentiabilityWitness *dw) {
PrettyStackTraceDifferentiabilityWitness _st(
"emitting differentiability witness for", dw->getKey());

// Don't emit declarations.
if (dw->isDeclaration())
return;

ConstantInitBuilder builder(*this);
auto diffWitnessContents = builder.beginStruct();

// TODO(marcrasi): When the differentiation pass generates JVP/VJP for
// witnesses, remove the nullptr case and add assertions that the JVP/VJP
// exist.
if (dw->getJVP()) {
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
} else {
diffWitnessContents.addNullPointer(Int8PtrTy);
}
if (dw->getVJP()) {
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
} else {
diffWitnessContents.addNullPointer(Int8PtrTy);
}

getAddrOfDifferentiabilityWitness(
dw, diffWitnessContents.finishAndCreateFuture());
}
// SWIFT_ENABLE_TENSORFLOW_END

/// True if a function's signature in LLVM carries polymorphic parameters.
/// Generic functions and protocol witnesses carry polymorphic parameters.
bool irgen::hasPolymorphicParameters(CanSILFunctionType ty) {
Expand Down
5 changes: 5 additions & 0 deletions lib/IRGen/IRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ IRGenModule::IRGenModule(IRGenerator &irgen,

DynamicReplacementKeyTy = createStructType(*this, "swift.dyn_repl_key",
{RelativeAddressTy, Int32Ty});

// SWIFT_ENABLE_TENSORFLOW
DifferentiabilityWitnessTy = createStructType(
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
// SWIFT_ENABLE_TENSORFLOW_END
}

IRGenModule::~IRGenModule() {
Expand Down
13 changes: 13 additions & 0 deletions lib/IRGen/IRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,10 @@ class IRGenModule {
*DynamicReplacementLinkEntryPtrTy; // %link_entry*
llvm::StructType *DynamicReplacementKeyTy; // { i32, i32}

// SWIFT_ENABLE_TENSORFLOW
llvm::StructType *DifferentiabilityWitnessTy; // { i8*, i8* }
// SWIFT_ENABLE_TENSORFLOW_END

llvm::GlobalVariable *TheTrivialPropertyDescriptor = nullptr;

/// Used to create unique names for class layout types with tail allocated
Expand Down Expand Up @@ -1233,6 +1237,9 @@ private: \
void emitSILFunction(SILFunction *f);
void emitSILWitnessTable(SILWitnessTable *wt);
void emitSILProperty(SILProperty *prop);
// SWIFT_ENABLE_TENSORFLOW
void emitSILDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
// SWIFT_ENABLE_TENSORFLOW END
void emitSILStaticInitializers();
llvm::Constant *emitFixedTypeLayout(CanType t, const FixedTypeInfo &ti);
void emitProtocolConformance(const ConformanceDescription &record);
Expand Down Expand Up @@ -1411,6 +1418,12 @@ private: \
const NormalProtocolConformance *C,
ConstantInit definition = ConstantInit());

// SWIFT_ENABLE_TENSORFLOW
llvm::Constant *
getAddrOfDifferentiabilityWitness(const SILDifferentiabilityWitness *witness,
ConstantInit definition = ConstantInit());
// SWIFT_ENABLE_TENSORFLOW_END

llvm::Function *
getAddrOfGenericWitnessTableInstantiationFunction(
const NormalProtocolConformance *C);
Expand Down
32 changes: 32 additions & 0 deletions lib/IRGen/Linking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,12 @@ std::string LinkEntity::mangleAsString() const {
case Kind::ReflectionAssociatedTypeDescriptor:
return mangler.mangleReflectionAssociatedTypeDescriptor(
getProtocolConformance());
// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return mangler.mangleSILDifferentiabilityWitnessKey(
{getSILDifferentiabilityWitness()->getOriginalFunction()->getName(),
getSILDifferentiabilityWitness()->getConfig()});
// SWIFT_ENABLE_TENSORFLOW_END
}
llvm_unreachable("bad entity kind!");
}
Expand Down Expand Up @@ -659,6 +665,10 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
case Kind::ExtensionDescriptor:
case Kind::AnonymousDescriptor:
return SILLinkage::Shared;
// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return getSILDifferentiabilityWitness()->getLinkage();
// SWIFT_ENABLE_TENSORFLOW_END
}
llvm_unreachable("bad link entity kind");
}
Expand Down Expand Up @@ -803,6 +813,10 @@ bool LinkEntity::isAvailableExternally(IRGenModule &IGM) const {
case Kind::DynamicallyReplaceableFunctionImpl:
case Kind::DynamicallyReplaceableFunctionKeyAST:
llvm_unreachable("Relative reference to unsupported link entity");
// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return true;
// SWIFT_ENABLE_TENSORFLOW_END
}
llvm_unreachable("bad link entity kind");
}
Expand Down Expand Up @@ -904,6 +918,10 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
return IGM.ObjCResilientClassStubTy;
}
llvm_unreachable("invalid metadata address");
// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return IGM.DifferentiabilityWitnessTy;
// SWIFT_ENABLE_TENSORFLOW_END
default:
llvm_unreachable("declaration LLVM type not specified");
}
Expand Down Expand Up @@ -956,6 +974,10 @@ Alignment LinkEntity::getAlignment(IRGenModule &IGM) const {
return Alignment(8);
case Kind::SILFunction:
return Alignment(1);
// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return IGM.getPointerAlignment();
// SWIFT_ENABLE_TENSORFLOW_END
default:
llvm_unreachable("alignment not specified");
}
Expand Down Expand Up @@ -1053,6 +1075,11 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
case Kind::ReflectionFieldDescriptor:
case Kind::CoroutineContinuationPrototype:
return false;

// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return false;
// SWIFT_ENABLE_TENSORFLOW_END
}

llvm_unreachable("Bad link entity kind");
Expand Down Expand Up @@ -1182,6 +1209,11 @@ const SourceFile *LinkEntity::getSourceFileForEmission() const {
case Kind::ValueWitness:
case Kind::ValueWitnessTable:
return nullptr;

// SWIFT_ENABLE_TENSORFLOW
case Kind::DifferentiabilityWitness:
return nullptr;
// SWIFT_ENABLE_TENSORFLOW_END
}

return sf;
Expand Down
8 changes: 8 additions & 0 deletions lib/SILOptimizer/IPO/DeadFunctionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,14 @@ class DeadFunctionElimination : FunctionLivenessComputation {
}
}

// SWIFT_ENABLE_TENSORFLOW
// Check differentiable function witness entries.
for (auto &dw : Module->getDifferentiabilityWitnessList()) {
if (dw.getJVP())
ensureAlive(dw.getJVP());
if (dw.getVJP())
ensureAlive(dw.getVJP());
}
}

/// Removes all dead methods from vtables and witness tables.
Expand Down
Loading