Skip to content

Commit 9f7ca52

Browse files
committed
[AutoDiff upstream] Add SIL differentiability witnesses.
SIL differentiability witnesses are a new top-level SIL construct mapping "original" SIL functions to derivative SIL functions. SIL differentiability witnesses have the following components: - "Original" `SILFunction`. - SIL linkage. - Differentiability parameter indices (`IndexSubset`). - Differentiability result indices (`IndexSubset`). - Derivative `GenericSignature` representing differentiability generic requirements (optional). - JVP derivative `SILFunction` (optional). - VJP derivative `SILFunction` (optional). - "Is serialized?" bit. This patch adds the `SILDifferentiabilityWitness` data structure, with documentation, parsing, and printing. Resolves TF-911. Todos: - TF-1136: upstream `SILDifferentiabilityWitness` serialization. - TF-1137: upstream `SILDifferentiabilityWitness` verification. - TF-1138: upstream `SILDifferentiabilityWitness` SILGen from `@differentiable` and `@derivative` attributes. - TF-20: robust mangling for `SILDifferentiabilityWitness` names.
1 parent 7934cdf commit 9f7ca52

18 files changed

+916
-50
lines changed

docs/SIL.rst

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,66 @@ variable cannot be used as l-value, i.e. the reference to the object cannot be
13161316
modified. As a consequence the variable cannot be accessed with ``global_addr``
13171317
but only with ``global_value``.
13181318

1319+
Differentiability Witnesses
1320+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
1321+
::
1322+
1323+
decl ::= sil-differentiability-witness
1324+
sil-differentiability-witness ::=
1325+
'sil_differentiability_witness'
1326+
sil-linkage?
1327+
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
1328+
'[' 'results' sil-differentiability-witness-function-index-list ']'
1329+
generic-parameter-clause?
1330+
sil-function-name ':' sil-type
1331+
sil-differentiability-witness-body?
1332+
1333+
sil-differentiability-witness-body ::=
1334+
'{' sil-differentiability-witness-entry?
1335+
sil-differentiability-witness-entry? '}'
1336+
1337+
sil-differentiability-witness-entry ::=
1338+
sil-differentiability-witness-entry-kind ':'
1339+
sil-entry-name ':' sil-type
1340+
1341+
sil-differentiability-witness-entry-kind ::= 'jvp' | 'vjp'
1342+
1343+
SIL encodes function differentiability via differentiability witnesses.
1344+
1345+
Differentiability witnesses map a "key" (including an "original" SIL function)
1346+
to derivative SIL functions.
1347+
1348+
Differentiability witnesses are keyed by the following:
1349+
1350+
- An "original" SIL function name.
1351+
- Differentiability parameter indices.
1352+
- Differentiability result indices.
1353+
- A generic parameter clause, representing differentiability generic
1354+
requirements.
1355+
1356+
Differentiability witnesses may have a body, specifying derivative functions for
1357+
the key. Verification checks that derivative functions have the expected type
1358+
based on the key.
1359+
1360+
::
1361+
1362+
sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : Differentiable> @id : $@convention(thin) (T) -> T {
1363+
jvp: @id_jvp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
1364+
vjp: @id_vjp : $@convention(thin) (T) -> (T, @owned @callee_guaranteed (T.TangentVector) -> T.TangentVector)
1365+
}
1366+
1367+
During SILGen, differentiability witnesses are emitted for the following:
1368+
1369+
- `@differentiable` declaration attributes.
1370+
- `@derivative` declaration attributes. Registered derivative functions
1371+
become differentiability witness entries.
1372+
1373+
The SIL differentiation transform canonicalizes differentiability witnesses,
1374+
filling in missing entries.
1375+
1376+
Differentiability witness entries are accessed via the
1377+
`differentiability_witness_function` instruction.
1378+
13191379
Dataflow Errors
13201380
---------------
13211381

include/swift/AST/ASTMangler.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,15 @@ class ASTMangler : public Mangler {
153153
Type FromType, Type ToType,
154154
Type SelfType,
155155
ModuleDecl *Module);
156-
156+
157+
/// Mangle a SIL differentiability witness key:
158+
/// - Mangled original function name.
159+
/// - Parameter indices.
160+
/// - Result indices.
161+
/// - Derivative generic signature (optional).
162+
std::string
163+
mangleSILDifferentiabilityWitnessKey(SILDifferentiabilityWitnessKey key);
164+
157165
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
158166
GenericSignature signature,
159167
CanType baseType,

include/swift/AST/AutoDiff.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ class TangentSpace {
222222
NominalTypeDecl *getNominal() const;
223223
};
224224

225+
/// The key type used for uniquing `SILDifferentiabilityWitness` in
226+
/// `SILModule`: original function name, parameter indices, result indices, and
227+
/// derivative generic signature.
228+
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
229+
225230
/// Automatic differentiation utility namespace.
226231
namespace autodiff {
227232

include/swift/AST/DiagnosticsParse.def

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,18 @@ ERROR(sil_witness_assoc_conf_not_found,none,
678678
ERROR(sil_witness_protocol_conformance_not_found,none,
679679
"sil protocol conformance not found", ())
680680

681+
// SIL differentiability witnesses
682+
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
683+
"expected '%0' in differentiability witness", (StringRef))
684+
ERROR(sil_diff_witness_serialized_declaration,none,
685+
"differentiability witness declaration should not be serialized", ())
686+
ERROR(sil_diff_witness_undefined,PointsToFirstBadToken,
687+
"reference to undefined differentiability witness", ())
688+
ERROR(sil_diff_witness_invalid_generic_signature,PointsToFirstBadToken,
689+
"expected witness generic signature '%0' does not have same generic "
690+
"parameters as original function generic signature '%1'",
691+
(StringRef, StringRef))
692+
681693
// SIL Coverage Map
682694
ERROR(sil_coverage_invalid_hash, none,
683695
"expected coverage hash", ())
@@ -1577,6 +1589,20 @@ ERROR(diff_params_clause_expected_parameter_unnamed,PointsToFirstBadToken,
15771589
ERROR(autodiff_attr_expected_original_decl_name,PointsToFirstBadToken,
15781590
"expected an original function name", ())
15791591

1592+
// SIL autodiff
1593+
ERROR(sil_autodiff_expected_lsquare,PointsToFirstBadToken,
1594+
"expected '[' to start the %0", (StringRef))
1595+
ERROR(sil_autodiff_expected_rsquare,PointsToFirstBadToken,
1596+
"expected ']' to complete the %0", (StringRef))
1597+
ERROR(sil_autodiff_expected_index_list,PointsToFirstBadToken,
1598+
"expected a space-separated list of indices, e.g. '0 1'", ())
1599+
ERROR(sil_autodiff_expected_index_list_label,PointsToFirstBadToken,
1600+
"expected label '%0' in index list", (StringRef))
1601+
ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
1602+
"expected the index of a parameter to differentiate with respect to", ())
1603+
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
1604+
"expected the index of a result to differentiate from", ())
1605+
15801606
//------------------------------------------------------------------------------
15811607
// MARK: Generics parsing diagnostics
15821608
//------------------------------------------------------------------------------

include/swift/Parse/ParseSILSupport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace swift {
3232
virtual bool parseSILGlobal(Parser &P) = 0;
3333
virtual bool parseSILWitnessTable(Parser &P) = 0;
3434
virtual bool parseSILDefaultWitnessTable(Parser &P) = 0;
35+
virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0;
3536
virtual bool parseSILCoverageMap(Parser &P) = 0;
3637
virtual bool parseSILProperty(Parser &P) = 0;
3738
virtual bool parseSILScope(Parser &P) = 0;
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines the SILDifferentiabilityWitness class, which maps an
14+
// original SILFunction and derivative configuration (parameter indices, result
15+
// indices, derivative generic signature) to derivative functions (JVP and VJP).
16+
//
17+
// SIL differentiability witnesses are generated from the `@differentiable`
18+
// and `@derivative` AST declaration attributes.
19+
//
20+
// Differentiability witnesses are canonicalized by the SIL differentiation
21+
// transform, which fills in missing derivative functions.
22+
//
23+
//===----------------------------------------------------------------------===//
24+
25+
#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
26+
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
27+
28+
#include "swift/AST/Attr.h"
29+
#include "swift/AST/AutoDiff.h"
30+
#include "swift/AST/GenericSignature.h"
31+
#include "swift/SIL/SILAllocated.h"
32+
#include "swift/SIL/SILLinkage.h"
33+
#include "llvm/ADT/ilist.h"
34+
#include "llvm/ADT/ilist_node.h"
35+
36+
namespace swift {
37+
38+
class SILPrintContext;
39+
40+
class SILDifferentiabilityWitness
41+
: public llvm::ilist_node<SILDifferentiabilityWitness>,
42+
public SILAllocated<SILDifferentiabilityWitness> {
43+
private:
44+
/// The module which contains the differentiability witness.
45+
SILModule &Module;
46+
/// The linkage of the differentiability witness.
47+
SILLinkage Linkage;
48+
/// The original function.
49+
SILFunction *OriginalFunction;
50+
/// The derivative configuration: parameter indices, result indices, and
51+
/// derivative generic signature (optional). The derivative generic signature
52+
/// may contain same-type requirements such that all generic parameters are
53+
/// bound to concrete types.
54+
AutoDiffConfig Config;
55+
/// The JVP (Jacobian-vector products) derivative function.
56+
SILFunction *JVP;
57+
/// The VJP (vector-Jacobian products) derivative function.
58+
SILFunction *VJP;
59+
/// Whether or not this differentiability witness is a declaration.
60+
bool IsDeclaration;
61+
/// Whether or not this differentiability witness is serialized, which allows
62+
/// devirtualization from another module.
63+
bool IsSerialized;
64+
/// The AST `@differentiable` or `@derivative` attribute from which the
65+
/// differentiability witness is generated. Used for diagnostics.
66+
/// Null if the differentiability witness is parsed from SIL or if it is
67+
/// deserialized.
68+
const DeclAttribute *Attribute = nullptr;
69+
70+
SILDifferentiabilityWitness(
71+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
72+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
73+
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
74+
bool isDeclaration, bool isSerialized, const DeclAttribute *attribute)
75+
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
76+
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
77+
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
78+
IsSerialized(isSerialized), Attribute(attribute) {}
79+
80+
public:
81+
static SILDifferentiabilityWitness *
82+
createDeclaration(SILModule &module, SILLinkage linkage,
83+
SILFunction *originalFunction,
84+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
85+
GenericSignature derivativeGenSig,
86+
const DeclAttribute *attribute = nullptr);
87+
88+
static SILDifferentiabilityWitness *createDefinition(
89+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
90+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
91+
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
92+
bool isSerialized, const DeclAttribute *attribute = nullptr);
93+
94+
void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
95+
bool isSerialized);
96+
97+
SILDifferentiabilityWitnessKey getKey() const;
98+
SILModule &getModule() const { return Module; }
99+
SILLinkage getLinkage() const { return Linkage; }
100+
SILFunction *getOriginalFunction() const { return OriginalFunction; }
101+
const AutoDiffConfig &getConfig() const { return Config; }
102+
IndexSubset *getParameterIndices() const { return Config.parameterIndices; }
103+
IndexSubset *getResultIndices() const { return Config.resultIndices; }
104+
GenericSignature getDerivativeGenericSignature() const {
105+
return Config.derivativeGenericSignature;
106+
}
107+
SILFunction *getJVP() const { return JVP; }
108+
SILFunction *getVJP() const { return VJP; }
109+
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
110+
switch (kind) {
111+
case AutoDiffDerivativeFunctionKind::JVP:
112+
return JVP;
113+
case AutoDiffDerivativeFunctionKind::VJP:
114+
return VJP;
115+
}
116+
}
117+
void setJVP(SILFunction *jvp) { JVP = jvp; }
118+
void setVJP(SILFunction *vjp) { VJP = vjp; }
119+
void setDerivative(AutoDiffDerivativeFunctionKind kind,
120+
SILFunction *derivative) {
121+
switch (kind) {
122+
case AutoDiffDerivativeFunctionKind::JVP:
123+
JVP = derivative;
124+
break;
125+
case AutoDiffDerivativeFunctionKind::VJP:
126+
VJP = derivative;
127+
break;
128+
}
129+
}
130+
bool isDeclaration() const { return IsDeclaration; }
131+
bool isDefinition() const { return !IsDeclaration; }
132+
bool isSerialized() const { return IsSerialized; }
133+
const DeclAttribute *getAttribute() const { return Attribute; }
134+
135+
/// Verify that the differentiability witness is well-formed.
136+
void verify(const SILModule &module) const;
137+
138+
void print(llvm::raw_ostream &os, bool verbose = false) const;
139+
void dump() const;
140+
};
141+
142+
} // end namespace swift
143+
144+
namespace llvm {
145+
146+
//===----------------------------------------------------------------------===//
147+
// ilist_traits for SILDifferentiabilityWitness
148+
//===----------------------------------------------------------------------===//
149+
150+
template <>
151+
struct ilist_traits<::swift::SILDifferentiabilityWitness>
152+
: public ilist_node_traits<::swift::SILDifferentiabilityWitness> {
153+
using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness;
154+
155+
public:
156+
static void deleteNode(SILDifferentiabilityWitness *DW) {
157+
DW->~SILDifferentiabilityWitness();
158+
}
159+
160+
private:
161+
void createNode(const SILDifferentiabilityWitness &);
162+
};
163+
164+
} // namespace llvm
165+
166+
#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H

include/swift/SIL/SILModule.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "swift/SIL/SILCoverageMap.h"
2929
#include "swift/SIL/SILDeclRef.h"
3030
#include "swift/SIL/SILDefaultWitnessTable.h"
31+
#include "swift/SIL/SILDifferentiabilityWitness.h"
3132
#include "swift/SIL/SILFunction.h"
3233
#include "swift/SIL/SILGlobalVariable.h"
3334
#include "swift/SIL/SILPrintContext.h"
@@ -113,6 +114,8 @@ class SILModule {
113114
using PropertyListType = llvm::ilist<SILProperty>;
114115
using WitnessTableListType = llvm::ilist<SILWitnessTable>;
115116
using DefaultWitnessTableListType = llvm::ilist<SILDefaultWitnessTable>;
117+
using DifferentiabilityWitnessListType =
118+
llvm::ilist<SILDifferentiabilityWitness>;
116119
using CoverageMapCollectionType =
117120
llvm::MapVector<StringRef, SILCoverageMap *>;
118121

@@ -131,6 +134,7 @@ class SILModule {
131134
friend SILBasicBlock;
132135
friend SILCoverageMap;
133136
friend SILDefaultWitnessTable;
137+
friend SILDifferentiabilityWitness;
134138
friend SILFunction;
135139
friend SILGlobalVariable;
136140
friend SILLayout;
@@ -194,6 +198,17 @@ class SILModule {
194198
/// The list of SILDefaultWitnessTables in the module.
195199
DefaultWitnessTableListType defaultWitnessTables;
196200

201+
/// Lookup table for SIL differentiability witnesses, keyed by mangled name.
202+
llvm::StringMap<SILDifferentiabilityWitness *> DifferentiabilityWitnessMap;
203+
204+
/// Lookup table for SILDifferentiabilityWitnesses, keyed by original
205+
/// function name.
206+
llvm::StringMap<llvm::SmallVector<SILDifferentiabilityWitness *, 1>>
207+
DifferentiabilityWitnessesByFunction;
208+
209+
/// The list of SILDifferentiabilityWitnesses in the module.
210+
DifferentiabilityWitnessListType differentiabilityWitnesses;
211+
197212
/// Declarations which are externally visible.
198213
///
199214
/// These are method declarations which are referenced from inlinable
@@ -455,6 +470,24 @@ class SILModule {
455470
return {defaultWitnessTables.begin(), defaultWitnessTables.end()};
456471
}
457472

473+
using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator;
474+
using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator;
475+
DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; }
476+
const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; } differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); }
477+
differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); }
478+
differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); }
479+
differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); }
480+
iterator_range<differentiability_witness_iterator>
481+
getDifferentiabilityWitnesses() {
482+
return {differentiabilityWitnesses.begin(),
483+
differentiabilityWitnesses.end()};
484+
}
485+
iterator_range<differentiability_witness_const_iterator>
486+
getDifferentiabilityWitnesses() const {
487+
return {differentiabilityWitnesses.begin(),
488+
differentiabilityWitnesses.end()};
489+
}
490+
458491
void addExternallyVisibleDecl(ValueDecl *decl) {
459492
externallyVisible.insert(decl);
460493
}
@@ -591,6 +624,17 @@ class SILModule {
591624
/// hierarchy of \p Class.
592625
SILFunction *lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member);
593626

627+
/// Look up the differentiability witness with the given name.
628+
SILDifferentiabilityWitness *lookUpDifferentiabilityWitness(StringRef name);
629+
630+
/// Look up the differentiability witness corresponding to the given key.
631+
SILDifferentiabilityWitness *
632+
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);
633+
634+
/// Look up the differentiability witness corresponding to the given function.
635+
llvm::ArrayRef<SILDifferentiabilityWitness *>
636+
lookUpDifferentiabilityWitnessesForFunction(StringRef name);
637+
594638
// Given a protocol, attempt to create a default witness table declaration
595639
// for it.
596640
SILDefaultWitnessTable *

0 commit comments

Comments
 (0)