Skip to content

Commit a174243

Browse files
authored
[AutoDiff upstream] Add SIL differentiability witness IRGen. (#29704)
SIL differentiability witnesses are a new top-level SIL construct mapping an "original" SIL function and derivative configuration to derivative SIL functions. This patch adds `SILDifferentiabilityWitness` IRGen. `SILDifferentiabilityWitness` has a fixed `{ i8*, i8* }` layout: JVP and VJP derivative function pointers. Resolves TF-1146.
1 parent b4386f4 commit a174243

File tree

11 files changed

+197
-2
lines changed

11 files changed

+197
-2
lines changed

include/swift/AST/PrettyStackTrace.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,24 @@ class PrettyStackTraceSelector : public llvm::PrettyStackTraceEntry {
202202
void print(llvm::raw_ostream &OS) const override;
203203
};
204204

205+
/// PrettyStackTraceDifferentiabilityWitness - Observe that we are processing a
206+
/// specific differentiability witness.
207+
class PrettyStackTraceDifferentiabilityWitness
208+
: public llvm::PrettyStackTraceEntry {
209+
const SILDifferentiabilityWitnessKey Key;
210+
const char *Action;
211+
212+
public:
213+
PrettyStackTraceDifferentiabilityWitness(
214+
const char *action, const SILDifferentiabilityWitnessKey key)
215+
: Key(key), Action(action) {}
216+
virtual void print(llvm::raw_ostream &OS) const;
217+
};
218+
219+
void printDifferentiabilityWitnessDescription(
220+
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
221+
bool addNewline = true);
222+
205223
} // end namespace swift
206224

207225
#endif

include/swift/IRGen/Linking.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ class LinkEntity {
346346
/// ProtocolConformance*.
347347
ProtocolWitnessTableLazyCacheVariable,
348348

349+
/// A SIL differentiability witness.
350+
DifferentiabilityWitness,
351+
349352
// Everything following this is a type kind.
350353

351354
/// A value witness for a type.
@@ -535,6 +538,14 @@ class LinkEntity {
535538
return getAssociatedConformanceByIndex(conformance->getProtocol(), index);
536539
}
537540

541+
void
542+
setForDifferentiabilityWitness(Kind kind,
543+
const SILDifferentiabilityWitness *witness) {
544+
Pointer = const_cast<void *>(static_cast<const void *>(witness));
545+
SecondaryPointer = nullptr;
546+
Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind));
547+
}
548+
538549
void setForType(Kind kind, CanType type) {
539550
assert(isTypeKind(kind));
540551
Pointer = type.getPointer();
@@ -835,6 +846,14 @@ class LinkEntity {
835846
return entity;
836847
}
837848

849+
static LinkEntity
850+
forDifferentiabilityWitness(const SILDifferentiabilityWitness *witness) {
851+
LinkEntity entity;
852+
entity.setForDifferentiabilityWitness(Kind::DifferentiabilityWitness,
853+
witness);
854+
return entity;
855+
}
856+
838857
static LinkEntity forProtocolWitnessTable(const RootProtocolConformance *C) {
839858
LinkEntity entity;
840859
entity.setForProtocolConformance(Kind::ProtocolWitnessTable, C);
@@ -1043,6 +1062,11 @@ class LinkEntity {
10431062
return reinterpret_cast<SILGlobalVariable*>(Pointer);
10441063
}
10451064

1065+
SILDifferentiabilityWitness *getSILDifferentiabilityWitness() const {
1066+
assert(getKind() == Kind::DifferentiabilityWitness);
1067+
return reinterpret_cast<SILDifferentiabilityWitness *>(Pointer);
1068+
}
1069+
10461070
const RootProtocolConformance *getRootProtocolConformance() const {
10471071
assert(isRootProtocolConformanceKind(getKind()));
10481072
return cast<RootProtocolConformance>(getProtocolConformance());

lib/AST/AutoDiff.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@
1717

1818
using namespace swift;
1919

20+
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
21+
s << "(parameters=";
22+
parameterIndices->print(s);
23+
s << " results=";
24+
resultIndices->print(s);
25+
if (derivativeGenericSignature) {
26+
s << " where=";
27+
derivativeGenericSignature->print(s);
28+
}
29+
s << ')';
30+
}
31+
2032
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
2133
// most once (for curried method types) is sufficient.
2234
static void unwrapCurryLevels(AnyFunctionType *fnTy,

lib/AST/PrettyStackTrace.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,18 @@ void PrettyStackTraceGenericSignature::print(llvm::raw_ostream &out) const {
273273
void PrettyStackTraceSelector::print(llvm::raw_ostream &out) const {
274274
out << "While " << Action << " '" << Selector << "'";
275275
}
276+
277+
void PrettyStackTraceDifferentiabilityWitness::print(
278+
llvm::raw_ostream &out) const {
279+
out << "While " << Action << ' ';
280+
printDifferentiabilityWitnessDescription(out, Key);
281+
}
282+
283+
void swift::printDifferentiabilityWitnessDescription(
284+
llvm::raw_ostream &out, const SILDifferentiabilityWitnessKey key,
285+
bool addNewline) {
286+
out << key.first << " ";
287+
key.second.print(out);
288+
if (addNewline)
289+
out << '\n';
290+
}

lib/IRGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_swift_host_library(swiftIRGen STATIC
1616
GenControl.cpp
1717
GenCoverage.cpp
1818
GenDecl.cpp
19+
GenDiffWitness.cpp
1920
GenEnum.cpp
2021
GenExistential.cpp
2122
GenFunc.cpp

lib/IRGen/GenDecl.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,21 @@ void IRGenerator::emitGlobalTopLevel(llvm::StringSet<> *linkerDirectives) {
10731073
CurrentIGMPtr IGM = getGenModule(prop.getDecl()->getInnermostDeclContext());
10741074
IGM->emitSILProperty(&prop);
10751075
}
1076-
1076+
1077+
// Emit differentiability witnesses.
1078+
for (auto &dw :
1079+
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
1080+
// Emit into same IRGenModule as the original function.
1081+
// NOTE(TF-894): Investigate whether `getGenModule(dw.getVJP())` is
1082+
// significant/desirable; `getGenModule` seems relevant for multi-threaded
1083+
// compilation. When the differentiation transform canonicalizes all
1084+
// differentiability witnesses to have JVP/VJP functions, we can assert
1085+
// that JVP/VJP functions exist and use `getGenModule(dw.getVJP())`.
1086+
CurrentIGMPtr IGM = getGenModule(dw.getOriginalFunction());
1087+
1088+
IGM->emitSILDifferentiabilityWitness(&dw);
1089+
}
1090+
10771091
// Emit code coverage mapping data.
10781092
PrimaryIGM->emitCoverageMapping();
10791093

@@ -4495,6 +4509,13 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
44954509
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
44964510
}
44974511

4512+
/// Look up the address of a differentiability witness.
4513+
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
4514+
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
4515+
auto entity = LinkEntity::forDifferentiabilityWitness(witness);
4516+
return getAddrOfLLVMVariable(entity, definition, DebugTypeInfo());
4517+
}
4518+
44984519
llvm::Function *
44994520
IRGenModule::getAddrOfAssociatedTypeWitnessTableAccessFunction(
45004521
const NormalProtocolConformance *conformance,

lib/IRGen/GenDiffWitness.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===--- GenDiffWitness.cpp - IRGen for differentiability witnesses -------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file implements IR generation for SIL differentiability witnesses.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "swift/AST/PrettyStackTrace.h"
18+
#include "swift/SIL/SILDifferentiabilityWitness.h"
19+
20+
#include "ConstantBuilder.h"
21+
#include "IRGenModule.h"
22+
23+
using namespace swift;
24+
using namespace irgen;
25+
26+
void IRGenModule::emitSILDifferentiabilityWitness(
27+
SILDifferentiabilityWitness *dw) {
28+
PrettyStackTraceDifferentiabilityWitness _st(
29+
"emitting differentiability witness for", dw->getKey());
30+
31+
// Don't emit declarations.
32+
if (dw->isDeclaration())
33+
return;
34+
35+
// Don't emit `public_external` witnesses.
36+
if (dw->getLinkage() == SILLinkage::PublicExternal)
37+
return;
38+
39+
ConstantInitBuilder builder(*this);
40+
auto diffWitnessContents = builder.beginStruct();
41+
42+
assert(dw->getJVP() &&
43+
"Differentiability witness definition should have JVP");
44+
assert(dw->getVJP() &&
45+
"Differentiability witness definition should have VJP");
46+
47+
diffWitnessContents.addBitCast(
48+
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
49+
diffWitnessContents.addBitCast(
50+
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
51+
52+
getAddrOfDifferentiabilityWitness(
53+
dw, diffWitnessContents.finishAndCreateFuture());
54+
}

lib/IRGen/IRGenModule.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,9 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
524524

525525
DynamicReplacementKeyTy = createStructType(*this, "swift.dyn_repl_key",
526526
{RelativeAddressTy, Int32Ty});
527+
528+
DifferentiabilityWitnessTy = createStructType(
529+
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
527530
}
528531

529532
IRGenModule::~IRGenModule() {

lib/IRGen/IRGenModule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ namespace swift {
9595
class RootProtocolConformance;
9696
struct SILDeclRef;
9797
class SILDefaultWitnessTable;
98+
class SILDifferentiabilityWitness;
9899
class SILGlobalVariable;
99100
class SILModule;
100101
class SILProperty;
@@ -672,6 +673,8 @@ class IRGenModule {
672673
*DynamicReplacementLinkEntryPtrTy; // %link_entry*
673674
llvm::StructType *DynamicReplacementKeyTy; // { i32, i32}
674675

676+
llvm::StructType *DifferentiabilityWitnessTy; // { i8*, i8* }
677+
675678
llvm::GlobalVariable *TheTrivialPropertyDescriptor = nullptr;
676679

677680
/// Used to create unique names for class layout types with tail allocated
@@ -1272,6 +1275,7 @@ private: \
12721275
void emitSILFunction(SILFunction *f);
12731276
void emitSILWitnessTable(SILWitnessTable *wt);
12741277
void emitSILProperty(SILProperty *prop);
1278+
void emitSILDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
12751279
void emitSILStaticInitializers();
12761280
llvm::Constant *emitFixedTypeLayout(CanType t, const FixedTypeInfo &ti);
12771281
void emitProtocolConformance(const ConformanceDescription &record);
@@ -1463,6 +1467,10 @@ private: \
14631467
llvm::Function *getAddrOfDefaultAssociatedConformanceAccessor(
14641468
AssociatedConformance requirement);
14651469

1470+
llvm::Constant *
1471+
getAddrOfDifferentiabilityWitness(const SILDifferentiabilityWitness *witness,
1472+
ConstantInit definition = ConstantInit());
1473+
14661474
Address getAddrOfObjCISAMask();
14671475

14681476
/// Retrieve the generic signature for the current generic context, or null if no

lib/IRGen/Linking.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ std::string LinkEntity::mangleAsString() const {
414414
case Kind::ReflectionAssociatedTypeDescriptor:
415415
return mangler.mangleReflectionAssociatedTypeDescriptor(
416416
getProtocolConformance());
417+
case Kind::DifferentiabilityWitness:
418+
return mangler.mangleSILDifferentiabilityWitnessKey(
419+
{getSILDifferentiabilityWitness()->getOriginalFunction()->getName(),
420+
getSILDifferentiabilityWitness()->getConfig()});
417421
}
418422
llvm_unreachable("bad entity kind!");
419423
}
@@ -659,6 +663,8 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
659663
case Kind::ExtensionDescriptor:
660664
case Kind::AnonymousDescriptor:
661665
return SILLinkage::Shared;
666+
case Kind::DifferentiabilityWitness:
667+
return getSILDifferentiabilityWitness()->getLinkage();
662668
}
663669
llvm_unreachable("bad link entity kind");
664670
}
@@ -783,6 +789,9 @@ bool LinkEntity::isAvailableExternally(IRGenModule &IGM) const {
783789
->getDeclContext()
784790
->getInnermostTypeContext());
785791
}
792+
793+
case Kind::DifferentiabilityWitness:
794+
return true;
786795

787796
case Kind::ObjCMetadataUpdateFunction:
788797
case Kind::ObjCResilientClassStub:
@@ -904,6 +913,8 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
904913
return IGM.ObjCResilientClassStubTy;
905914
}
906915
llvm_unreachable("invalid metadata address");
916+
case Kind::DifferentiabilityWitness:
917+
return IGM.DifferentiabilityWitnessTy;
907918
default:
908919
llvm_unreachable("declaration LLVM type not specified");
909920
}
@@ -951,6 +962,7 @@ Alignment LinkEntity::getAlignment(IRGenModule &IGM) const {
951962
case Kind::OpaqueTypeDescriptorAccessorKey:
952963
case Kind::OpaqueTypeDescriptorAccessorVar:
953964
case Kind::ObjCResilientClassStub:
965+
case Kind::DifferentiabilityWitness:
954966
return IGM.getPointerAlignment();
955967
case Kind::TypeMetadataDemanglingCacheVariable:
956968
return Alignment(8);
@@ -1052,6 +1064,7 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
10521064
case Kind::ReflectionBuiltinDescriptor:
10531065
case Kind::ReflectionFieldDescriptor:
10541066
case Kind::CoroutineContinuationPrototype:
1067+
case Kind::DifferentiabilityWitness:
10551068
return false;
10561069
}
10571070

@@ -1181,6 +1194,7 @@ const SourceFile *LinkEntity::getSourceFileForEmission() const {
11811194
case Kind::ReflectionBuiltinDescriptor:
11821195
case Kind::ValueWitness:
11831196
case Kind::ValueWitnessTable:
1197+
case Kind::DifferentiabilityWitness:
11841198
return nullptr;
11851199
}
11861200

test/AutoDiff/SIL/Serialization/sil_differentiability_witness.sil

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
// RUN: %empty-directory(%t)
88
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main
99
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.sil -module-name main
10-
1110
// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files.
1211
// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil
1312
// RUN: %target-sil-opt %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s
1413

14+
// IRGen test.
15+
16+
// RUN: %target-swift-frontend -emit-ir %s | %FileCheck --check-prefix=IRGEN %s
17+
1518
// REQUIRES: differentiable_programming
1619
// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround.
1720
// REQUIRES: shell
@@ -49,6 +52,11 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@conven
4952
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
5053
// ROUNDTRIP: }
5154

55+
// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
56+
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
57+
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
58+
// IRGEN-SAME: }
59+
5260
// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.
5361

5462
sil @externalFn2 : $@convention(thin) (Float) -> Float
@@ -68,6 +76,11 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@conven
6876
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
6977
// ROUNDTRIP: }
7078

79+
// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
80+
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
81+
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
82+
// IRGEN-SAME: }
83+
7184
// Test SIL differentiability witness declaration.
7285

7386
sil @externalFn3 : $@convention(thin) (Float) -> Float
@@ -77,6 +90,8 @@ sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@conven
7790
// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
7891
// ROUNDTRIP: sil_differentiability_witness{{( public_external)?}} [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}
7992

93+
// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }
94+
8095
// Test public non-generic function.
8196
// SIL differentiability witness:
8297
// - Has public linkage (implicit).
@@ -108,6 +123,11 @@ sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thi
108123
// ROUNDTRIP: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
109124
// ROUNDTRIP: }
110125

126+
// IRGEN-LABEL: @AD__foo_PSRS ={{( protected)?}} global { i8*, i8* } {
127+
// IRGEN-SAME: @AD__foo__jvp_src_0_wrt_0
128+
// IRGEN-SAME: @AD__foo__vjp_src_0_wrt_0
129+
// IRGEN-SAME: }
130+
111131
// Test internal generic function.
112132
// SIL differentiability witness:
113133
// - Has hidden linkage.
@@ -140,3 +160,8 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where
140160
// ROUNDTRIP: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
141161
// ROUNDTRIP: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
142162
// ROUNDTRIP: }
163+
164+
// IRGEN-LABEL: @AD__generic_PSSRS16_Differentiation14DifferentiableRzl = hidden global { i8*, i8* } {
165+
// IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1
166+
// IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1
167+
// IRGEN-SAME: }

0 commit comments

Comments
 (0)