Skip to content

[AutoDiff upstream] Add SIL differentiability witness IRGen. #29704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/swift/AST/PrettyStackTrace.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,24 @@ class PrettyStackTraceSelector : public llvm::PrettyStackTraceEntry {
void print(llvm::raw_ostream &OS) const override;
};

/// PrettyStackTraceDifferentiabilityWitness - Observe that we are processing a
/// specific differentiability witness.
class PrettyStackTraceDifferentiabilityWitness
: public llvm::PrettyStackTraceEntry {
const SILDifferentiabilityWitnessKey Key;
const char *Action;

public:
PrettyStackTraceDifferentiabilityWitness(
const char *action, const SILDifferentiabilityWitnessKey key)
: Key(key), Action(action) {}
virtual void print(llvm::raw_ostream &OS) const;
};

void printDifferentiabilityWitnessDescription(
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
bool addNewline = true);

} // end namespace swift

#endif
24 changes: 24 additions & 0 deletions include/swift/IRGen/Linking.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ class LinkEntity {
/// ProtocolConformance*.
ProtocolWitnessTableLazyCacheVariable,

/// A SIL differentiability witness.
DifferentiabilityWitness,

// Everything following this is a type kind.

/// A value witness for a type.
Expand Down Expand Up @@ -535,6 +538,14 @@ class LinkEntity {
return getAssociatedConformanceByIndex(conformance->getProtocol(), index);
}

void
setForDifferentiabilityWitness(Kind kind,
const SILDifferentiabilityWitness *witness) {
Pointer = const_cast<void *>(static_cast<const void *>(witness));
SecondaryPointer = nullptr;
Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind));
}

void setForType(Kind kind, CanType type) {
assert(isTypeKind(kind));
Pointer = type.getPointer();
Expand Down Expand Up @@ -835,6 +846,14 @@ class LinkEntity {
return entity;
}

static LinkEntity
forDifferentiabilityWitness(const SILDifferentiabilityWitness *witness) {
LinkEntity entity;
entity.setForDifferentiabilityWitness(Kind::DifferentiabilityWitness,
witness);
return entity;
}

static LinkEntity forProtocolWitnessTable(const RootProtocolConformance *C) {
LinkEntity entity;
entity.setForProtocolConformance(Kind::ProtocolWitnessTable, C);
Expand Down Expand Up @@ -1043,6 +1062,11 @@ class LinkEntity {
return reinterpret_cast<SILGlobalVariable*>(Pointer);
}

SILDifferentiabilityWitness *getSILDifferentiabilityWitness() const {
assert(getKind() == Kind::DifferentiabilityWitness);
return reinterpret_cast<SILDifferentiabilityWitness *>(Pointer);
}

const RootProtocolConformance *getRootProtocolConformance() const {
assert(isRootProtocolConformanceKind(getKind()));
return cast<RootProtocolConformance>(getProtocolConformance());
Expand Down
12 changes: 12 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@

using namespace swift;

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
s << " results=";
resultIndices->print(s);
if (derivativeGenericSignature) {
s << " where=";
derivativeGenericSignature->print(s);
}
s << ')';
}

// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
// most once (for curried method types) is sufficient.
static void unwrapCurryLevels(AnyFunctionType *fnTy,
Expand Down
15 changes: 15 additions & 0 deletions lib/AST/PrettyStackTrace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,18 @@ void PrettyStackTraceGenericSignature::print(llvm::raw_ostream &out) const {
void PrettyStackTraceSelector::print(llvm::raw_ostream &out) const {
out << "While " << Action << " '" << Selector << "'";
}

void PrettyStackTraceDifferentiabilityWitness::print(
llvm::raw_ostream &out) const {
out << "While " << Action << ' ';
printDifferentiabilityWitnessDescription(out, Key);
}

void swift::printDifferentiabilityWitnessDescription(
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
bool addNewline) {
out << key.first << " ";
key.second.print(out);
if (addNewline)
out << '\n';
}
1 change: 1 addition & 0 deletions lib/IRGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_swift_host_library(swiftIRGen STATIC
GenControl.cpp
GenCoverage.cpp
GenDecl.cpp
GenDiffWitness.cpp
GenEnum.cpp
GenExistential.cpp
GenFunc.cpp
Expand Down
23 changes: 22 additions & 1 deletion lib/IRGen/GenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,21 @@ void IRGenerator::emitGlobalTopLevel(llvm::StringSet<> *linkerDirectives) {
CurrentIGMPtr IGM = getGenModule(prop.getDecl()->getInnermostDeclContext());
IGM->emitSILProperty(&prop);
}


// Emit differentiability witnesses.
for (auto &dw :
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
// Emit into same IRGenModule as the original function.
// NOTE(TF-894): Investigate whether `getGenModule(dw.getVJP())` is
// significant/desirable; `getGenModule` seems relevant for multi-threaded
// compilation. When the differentiation transform canonicalizes all
// differentiability witnesses to have JVP/VJP functions, we can assert
// that JVP/VJP functions exist and use `getGenModule(dw.getVJP())`.
CurrentIGMPtr IGM = getGenModule(dw.getOriginalFunction());

IGM->emitSILDifferentiabilityWitness(&dw);
}

// Emit code coverage mapping data.
PrimaryIGM->emitCoverageMapping();

Expand Down Expand Up @@ -4495,6 +4509,13 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
}

/// Look up the address of a differentiability witness.
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
auto entity = LinkEntity::forDifferentiabilityWitness(witness);
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
}

llvm::Function *
IRGenModule::getAddrOfAssociatedTypeWitnessTableAccessFunction(
const NormalProtocolConformance *conformance,
Expand Down
54 changes: 54 additions & 0 deletions lib/IRGen/GenDiffWitness.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===--- GenDiffWitness.cpp - IRGen for differentiability witnesses -------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements IR generation for SIL differentiability witnesses.
//
//===----------------------------------------------------------------------===//

#include "swift/AST/PrettyStackTrace.h"
#include "swift/SIL/SILDifferentiabilityWitness.h"

#include "ConstantBuilder.h"
#include "IRGenModule.h"

using namespace swift;
using namespace irgen;

void IRGenModule::emitSILDifferentiabilityWitness(
SILDifferentiabilityWitness *dw) {
PrettyStackTraceDifferentiabilityWitness _st(
"emitting differentiability witness for", dw->getKey());

// Don't emit declarations.
if (dw->isDeclaration())
return;

// Don't emit `public_external` witnesses.
if (dw->getLinkage() == SILLinkage::PublicExternal)
return;

ConstantInitBuilder builder(*this);
auto diffWitnessContents = builder.beginStruct();

assert(dw->getJVP() &&
"Differentiability witness definition should have JVP");
assert(dw->getVJP() &&
"Differentiability witness definition should have VJP");

diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);

getAddrOfDifferentiabilityWitness(
dw, diffWitnessContents.finishAndCreateFuture());
}
3 changes: 3 additions & 0 deletions lib/IRGen/IRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ IRGenModule::IRGenModule(IRGenerator &irgen,

DynamicReplacementKeyTy = createStructType(*this, "swift.dyn_repl_key",
{RelativeAddressTy, Int32Ty});

DifferentiabilityWitnessTy = createStructType(
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
}

IRGenModule::~IRGenModule() {
Expand Down
8 changes: 8 additions & 0 deletions lib/IRGen/IRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace swift {
class RootProtocolConformance;
struct SILDeclRef;
class SILDefaultWitnessTable;
class SILDifferentiabilityWitness;
class SILGlobalVariable;
class SILModule;
class SILProperty;
Expand Down Expand Up @@ -672,6 +673,8 @@ class IRGenModule {
*DynamicReplacementLinkEntryPtrTy; // %link_entry*
llvm::StructType *DynamicReplacementKeyTy; // { i32, i32}

llvm::StructType *DifferentiabilityWitnessTy; // { i8*, i8* }

llvm::GlobalVariable *TheTrivialPropertyDescriptor = nullptr;

/// Used to create unique names for class layout types with tail allocated
Expand Down Expand Up @@ -1272,6 +1275,7 @@ private: \
void emitSILFunction(SILFunction *f);
void emitSILWitnessTable(SILWitnessTable *wt);
void emitSILProperty(SILProperty *prop);
void emitSILDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
void emitSILStaticInitializers();
llvm::Constant *emitFixedTypeLayout(CanType t, const FixedTypeInfo &ti);
void emitProtocolConformance(const ConformanceDescription &record);
Expand Down Expand Up @@ -1463,6 +1467,10 @@ private: \
llvm::Function *getAddrOfDefaultAssociatedConformanceAccessor(
AssociatedConformance requirement);

llvm::Constant *
getAddrOfDifferentiabilityWitness(const SILDifferentiabilityWitness *witness,
ConstantInit definition = ConstantInit());

Address getAddrOfObjCISAMask();

/// Retrieve the generic signature for the current generic context, or null if no
Expand Down
14 changes: 14 additions & 0 deletions lib/IRGen/Linking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ std::string LinkEntity::mangleAsString() const {
case Kind::ReflectionAssociatedTypeDescriptor:
return mangler.mangleReflectionAssociatedTypeDescriptor(
getProtocolConformance());
case Kind::DifferentiabilityWitness:
return mangler.mangleSILDifferentiabilityWitnessKey(
{getSILDifferentiabilityWitness()->getOriginalFunction()->getName(),
getSILDifferentiabilityWitness()->getConfig()});
}
llvm_unreachable("bad entity kind!");
}
Expand Down Expand Up @@ -659,6 +663,8 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
case Kind::ExtensionDescriptor:
case Kind::AnonymousDescriptor:
return SILLinkage::Shared;
case Kind::DifferentiabilityWitness:
return getSILDifferentiabilityWitness()->getLinkage();
}
llvm_unreachable("bad link entity kind");
}
Expand Down Expand Up @@ -783,6 +789,9 @@ bool LinkEntity::isAvailableExternally(IRGenModule &IGM) const {
->getDeclContext()
->getInnermostTypeContext());
}

case Kind::DifferentiabilityWitness:
return true;

case Kind::ObjCMetadataUpdateFunction:
case Kind::ObjCResilientClassStub:
Expand Down Expand Up @@ -904,6 +913,8 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
return IGM.ObjCResilientClassStubTy;
}
llvm_unreachable("invalid metadata address");
case Kind::DifferentiabilityWitness:
return IGM.DifferentiabilityWitnessTy;
default:
llvm_unreachable("declaration LLVM type not specified");
}
Expand Down Expand Up @@ -951,6 +962,7 @@ Alignment LinkEntity::getAlignment(IRGenModule &IGM) const {
case Kind::OpaqueTypeDescriptorAccessorKey:
case Kind::OpaqueTypeDescriptorAccessorVar:
case Kind::ObjCResilientClassStub:
case Kind::DifferentiabilityWitness:
return IGM.getPointerAlignment();
case Kind::TypeMetadataDemanglingCacheVariable:
return Alignment(8);
Expand Down Expand Up @@ -1052,6 +1064,7 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
case Kind::ReflectionBuiltinDescriptor:
case Kind::ReflectionFieldDescriptor:
case Kind::CoroutineContinuationPrototype:
case Kind::DifferentiabilityWitness:
return false;
}

Expand Down Expand Up @@ -1181,6 +1194,7 @@ const SourceFile *LinkEntity::getSourceFileForEmission() const {
case Kind::ReflectionBuiltinDescriptor:
case Kind::ValueWitness:
case Kind::ValueWitnessTable:
case Kind::DifferentiabilityWitness:
return nullptr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.sil -module-name main

// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files.
// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil
// RUN: %target-sil-opt %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s

// IRGen test.

// RUN: %target-swift-frontend -emit-ir %s | %FileCheck --check-prefix=IRGEN %s

// REQUIRES: differentiable_programming
// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround.
// REQUIRES: shell
Expand Down Expand Up @@ -49,6 +52,11 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@conven
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.

sil @externalFn2 : $@convention(thin) (Float) -> Float
Expand All @@ -68,6 +76,11 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@conven
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness declaration.

sil @externalFn3 : $@convention(thin) (Float) -> Float
Expand All @@ -77,6 +90,8 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@conven
// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
// ROUNDTRIP: sil_differentiability_witness{{( public_external)?}} [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}

// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }

// Test public non-generic function.
// SIL differentiability witness:
// - Has public linkage (implicit).
Expand Down Expand Up @@ -108,6 +123,11 @@ sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thi
// ROUNDTRIP: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__foo_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__foo__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__foo__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test internal generic function.
// SIL differentiability witness:
// - Has hidden linkage.
Expand Down Expand Up @@ -140,3 +160,8 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where
// ROUNDTRIP: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
// ROUNDTRIP: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__generic_PSSRS16_Differentiation14DifferentiableRzl = hidden global { i8*, i8* } {
// IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1
// IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1
// IRGEN-SAME: }