Skip to content

Commit a5dc918

Browse files
authored
[AutoDiff] Add SIL differentiability witnesses. (#27487)
SIL differentiability witnesses are a new top-level SIL construct mapping "original" SIL functions to derivative SIL functions. They will replace SIL function `[differentiable]` attributes, additionally enabling cross-module retroactive derivative registration. SIL differentiability witnesses have the following components: - Original `SILFunction`. - Linkage. - Parameter indices (`IndexSubset`). - Result indices (`IndexSubset`). - Derivative generic signature (optional). - JVP `SILFunction` (optional). - VJP `SILFunction` (optional). - "Is serialized?" bit. This patch adds the `SILDifferentiabilityWitness` data structure, along with parsing, printing, verification, and serialization (including lookup by key). The TF-866 master issue tracks follow-up, including SILGen and differentiation transform changes.
1 parent 9c76a29 commit a5dc918

24 files changed

+1135
-17
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,31 @@ class ASTMangler : public Mangler {
155155
ModuleDecl *Module);
156156

157157
// SWIFT_ENABLE_TENSORFLOW
158-
// Mangle the derivative function (JVP/VJP) with the given:
159-
// - Mangled original function name.
160-
// - Derivative function kind.
161-
// - Parameter/result indices.
158+
/// Mangle the derivative function (JVP/VJP) with the given:
159+
/// - Mangled original function name.
160+
/// - Derivative function kind.
161+
/// - Parameter/result indices.
162162
std::string mangleAutoDiffDerivativeFunctionHelper(
163163
StringRef name, AutoDiffDerivativeFunctionKind kind,
164164
const SILAutoDiffIndices &indices);
165165

166-
// SWIFT_ENABLE_TENSORFLOW
167-
// Mangle the autodiff linear map (differential/pullback) with the given:
168-
// - Mangled original function name.
169-
// - Linear map kind.
170-
// - Parameter/result indices.
166+
/// Mangle the autodiff linear map (differential/pullback) with the given:
167+
/// - Mangled original function name.
168+
/// - Linear map kind.
169+
/// - Parameter/result indices.
171170
std::string mangleAutoDiffLinearMapHelper(
172171
StringRef name, AutoDiffLinearMapKind kind,
173172
const SILAutoDiffIndices &indices);
174173

174+
/// Mangle a SIL differentiability witness key.
175+
/// - Mangled original function name.
176+
/// - Parameter indices.
177+
/// - Result indices.
178+
/// - Derivative generic signature (optional).
179+
std::string mangleSILDifferentiabilityWitnessKey(
180+
SILDifferentiabilityWitnessKey key);
181+
// SWIFT_ENABLE_TENSORFLOW END
182+
175183
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
176184
GenericSignature *signature,
177185
CanType baseType,

include/swift/AST/AutoDiff.h

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,16 @@ struct AutoDiffDerivativeFunctionKind {
208208
}
209209
};
210210

211+
/// Identifies an autodiff derivative function configuration:
212+
/// - Parameter indices.
213+
/// - Result indices.
214+
/// - Derivative generic signature (optional).
215+
struct AutoDiffConfig {
216+
IndexSubset *parameterIndices;
217+
IndexSubset *resultIndices;
218+
GenericSignature *derivativeGenericSignature;
219+
};
220+
211221
/// In conjunction with the original function declaration, identifies an
212222
/// autodiff derivative function.
213223
///
@@ -218,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
218228
IndexSubset *const parameterIndices;
219229

220230
AutoDiffDerivativeFunctionIdentifier(
221-
AutoDiffDerivativeFunctionKind kind,
222-
IndexSubset *parameterIndices) :
231+
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
223232
kind(kind), parameterIndices(parameterIndices) {}
224233

225234
public:
@@ -238,6 +247,11 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
238247
}
239248
};
240249

250+
/// The key type used for uniquing `SILDifferentiabilityWitness` in
251+
/// `SILModule`: original function name, parameter indices, result indices, and
252+
/// derivative generic signature.
253+
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
254+
241255
/// Automatic differentiation utility namespace.
242256
namespace autodiff {
243257
/// Appends the subset's parameter's types to `result`, in the order in
@@ -363,10 +377,42 @@ class VectorSpace {
363377

364378
namespace llvm {
365379

380+
using swift::AutoDiffConfig;
381+
using swift::AutoDiffDerivativeFunctionKind;
382+
using swift::GenericSignature;
383+
using swift::IndexSubset;
366384
using swift::SILAutoDiffIndices;
367385

368386
template<typename T> struct DenseMapInfo;
369387

388+
template<> struct DenseMapInfo<AutoDiffConfig> {
389+
static AutoDiffConfig getEmptyKey() {
390+
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
391+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
392+
static_cast<GenericSignature *>(ptr)};
393+
}
394+
395+
static AutoDiffConfig getTombstoneKey() {
396+
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
397+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
398+
static_cast<GenericSignature *>(ptr)};
399+
}
400+
401+
static unsigned getHashValue(const AutoDiffConfig &Val) {
402+
unsigned combinedHash = hash_combine(
403+
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
404+
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
405+
DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
406+
return combinedHash;
407+
}
408+
409+
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
410+
return LHS.parameterIndices == RHS.parameterIndices &&
411+
LHS.resultIndices == RHS.resultIndices &&
412+
LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
413+
}
414+
};
415+
370416
template<> struct DenseMapInfo<SILAutoDiffIndices> {
371417
static SILAutoDiffIndices getEmptyKey() {
372418
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

include/swift/AST/DiagnosticsParse.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,16 @@ ERROR(sil_witness_assoc_conf_not_found,none,
686686
ERROR(sil_witness_protocol_conformance_not_found,none,
687687
"sil protocol conformance not found", ())
688688

689+
// SIL differentiability witnesses
690+
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691+
"expected '%0' in differentiability witness", (StringRef))
692+
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
693+
"expected a space-separated list of indices, e.g. '0 1'", ())
694+
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
695+
"expected a parameter index to differentiate with respect to", ())
696+
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
697+
"expected a result index to differentiate with respect to", ())
698+
689699
// SIL Coverage Map
690700
ERROR(sil_coverage_func_not_found, none,
691701
"sil function not found %0", (Identifier))

include/swift/Parse/ParseSILSupport.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ namespace swift {
3232
virtual bool parseSILGlobal(Parser &P) = 0;
3333
virtual bool parseSILWitnessTable(Parser &P) = 0;
3434
virtual bool parseSILDefaultWitnessTable(Parser &P) = 0;
35+
// SWIFT_ENABLE_TENSORFLOW
36+
virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0;
37+
// SWIFT_ENABLE_TENSORFLOW END
3538
virtual bool parseSILCoverageMap(Parser &P) = 0;
3639
virtual bool parseSILProperty(Parser &P) = 0;
3740
virtual bool parseSILScope(Parser &P) = 0;
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines the SILDifferentiabilityWitness class, which maps an
14+
// original SILFunction and derivative configuration (parameter indices, result
15+
// indices, derivative generic signature) to derivative functions (JVP and VJP).
16+
//
17+
// SIL differentiability witnesses are generated from the `@differentiable`
18+
// and `@differentiating` attributes AST declaration attributes.
19+
// Differentiability witnesses are canonicalized by the differentiation SIL
20+
// transform, which fills in missing derivative functions. Canonical
21+
// differentiability witnesses from other modules can be deserialized to look up
22+
// derivative functions.
23+
//
24+
//===----------------------------------------------------------------------===//
25+
26+
#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
27+
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
28+
29+
#include "swift/AST/Attr.h"
30+
#include "swift/AST/AutoDiff.h"
31+
#include "swift/AST/GenericSignature.h"
32+
#include "swift/SIL/SILAllocated.h"
33+
#include "llvm/ADT/ilist_node.h"
34+
#include "llvm/ADT/ilist.h"
35+
36+
namespace swift {
37+
38+
class SILPrintContext;
39+
40+
class SILDifferentiabilityWitness
41+
: public llvm::ilist_node<SILDifferentiabilityWitness>,
42+
public SILAllocated<SILDifferentiabilityWitness>
43+
{
44+
private:
45+
/// The module which contains the differentiability witness.
46+
SILModule &module;
47+
/// The linkage of the differentiability witness.
48+
SILLinkage linkage;
49+
/// The original function.
50+
SILFunction *originalFunction;
51+
/// The parameter indices.
52+
IndexSubset *parameterIndices;
53+
/// The result indices.
54+
IndexSubset *resultIndices;
55+
/// The derivative generic signature (optional).
56+
GenericSignature *derivativeGenericSignature;
57+
/// The JVP (Jacobian-vector products) derivative function.
58+
SILFunction *jvp;
59+
/// The VJP (vector-Jacobian products) derivative function.
60+
SILFunction *vjp;
61+
/// Whether or not this differentiability witness is serialized, which allows
62+
/// devirtualization from another module.
63+
bool serialized;
64+
/// The AST `@differentiable` or `@differentiating` attribute from which the
65+
/// differentiability witness is generated. Used for diagnostics.
66+
/// Null if the differentiability witness is parsed from SIL or if it is
67+
/// deserialized.
68+
DeclAttribute *attribute = nullptr;
69+
70+
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
71+
SILFunction *originalFunction,
72+
IndexSubset *parameterIndices,
73+
IndexSubset *resultIndices,
74+
GenericSignature *derivativeGenSig,
75+
SILFunction *jvp, SILFunction *vjp,
76+
bool isSerialized, DeclAttribute *attribute)
77+
: module(module), linkage(linkage), originalFunction(originalFunction),
78+
parameterIndices(parameterIndices), resultIndices(resultIndices),
79+
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
80+
serialized(isSerialized), attribute(attribute) {}
81+
82+
public:
83+
static SILDifferentiabilityWitness *create(
84+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
85+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
86+
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
87+
bool isSerialized, DeclAttribute *attribute = nullptr);
88+
89+
SILDifferentiabilityWitnessKey getKey() const;
90+
SILModule &getModule() const { return module; }
91+
SILLinkage getLinkage() const { return linkage; }
92+
SILFunction *getOriginalFunction() const { return originalFunction; }
93+
IndexSubset *getParameterIndices() const {
94+
return parameterIndices;
95+
}
96+
IndexSubset *getResultIndices() const {
97+
return resultIndices;
98+
}
99+
GenericSignature *getDerivativeGenericSignature() const {
100+
return derivativeGenericSignature;
101+
}
102+
SILFunction *getJVP() const { return jvp; }
103+
SILFunction *getVJP() const { return vjp; }
104+
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
105+
switch (kind) {
106+
case AutoDiffDerivativeFunctionKind::JVP: return jvp;
107+
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
108+
}
109+
}
110+
void setJVP(SILFunction *jvp) { this->jvp = jvp; }
111+
void setVJP(SILFunction *vjp) { this->vjp = vjp; }
112+
void setDerivative(AutoDiffDerivativeFunctionKind kind,
113+
SILFunction *derivative) {
114+
switch (kind) {
115+
case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
116+
case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
117+
}
118+
}
119+
bool isSerialized() const { return serialized; }
120+
DeclAttribute *getAttribute() const { return attribute; }
121+
122+
/// Verify that the differentiability witness is well-formed.
123+
void verify(const SILModule &module) const;
124+
125+
void print(llvm::raw_ostream &os, bool verbose = false) const;
126+
void dump() const;
127+
};
128+
129+
} // end namespace swift
130+
131+
namespace llvm {
132+
133+
//===----------------------------------------------------------------------===//
134+
// ilist_traits for SILDifferentiabilityWitness
135+
//===----------------------------------------------------------------------===//
136+
137+
template <>
138+
struct ilist_traits<::swift::SILDifferentiabilityWitness>
139+
: public ilist_node_traits<::swift::SILDifferentiabilityWitness> {
140+
using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness;
141+
142+
public:
143+
static void deleteNode(SILDifferentiabilityWitness *DW) {
144+
DW->~SILDifferentiabilityWitness();
145+
}
146+
147+
private:
148+
void createNode(const SILDifferentiabilityWitness &);
149+
};
150+
151+
} // namespace llvm
152+
153+
#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H

include/swift/SIL/SILModule.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "swift/SIL/SILCoverageMap.h"
2929
#include "swift/SIL/SILDeclRef.h"
3030
#include "swift/SIL/SILDefaultWitnessTable.h"
31+
// SWIFT_ENABLE_TENSORFLOW
32+
#include "swift/SIL/SILDifferentiabilityWitness.h"
3133
#include "swift/SIL/SILFunction.h"
3234
#include "swift/SIL/SILGlobalVariable.h"
3335
#include "swift/SIL/SILPrintContext.h"
@@ -113,6 +115,10 @@ class SILModule {
113115
using PropertyListType = llvm::ilist<SILProperty>;
114116
using WitnessTableListType = llvm::ilist<SILWitnessTable>;
115117
using DefaultWitnessTableListType = llvm::ilist<SILDefaultWitnessTable>;
118+
// SWIFT_ENABLE_TENSORFLOW
119+
using DifferentiabilityWitnessListType =
120+
llvm::ilist<SILDifferentiabilityWitness>;
121+
// SWIFT_ENABLE_TENSORFLOW END
116122
using CoverageMapCollectionType =
117123
llvm::MapVector<StringRef, SILCoverageMap *>;
118124

@@ -139,6 +145,9 @@ class SILModule {
139145
friend SILProperty;
140146
friend SILUndef;
141147
friend SILWitnessTable;
148+
// SWIFT_ENABLE_TENSORFLOW
149+
friend SILDifferentiabilityWitness;
150+
// SWIFT_ENABLE_TENSORFLOW END
142151
friend Lowering::SILGenModule;
143152
friend Lowering::TypeConverter;
144153
class SerializationCallback;
@@ -194,6 +203,17 @@ class SILModule {
194203
/// The list of SILDefaultWitnessTables in the module.
195204
DefaultWitnessTableListType defaultWitnessTables;
196205

206+
// SWIFT_ENABLE_TENSORFLOW
207+
/// Lookup table for SIL differentiability witnesses from original functions.
208+
/// Indexed by key type: original function, parameter indices, result indices,
209+
/// and derivative generic signature.
210+
llvm::DenseMap<SILDifferentiabilityWitnessKey, SILDifferentiabilityWitness *>
211+
DifferentiabilityWitnessMap;
212+
213+
/// The list of SILDifferentiabilityWitnesses in the module.
214+
DifferentiabilityWitnessListType differentiabilityWitnesses;
215+
// SWIFT_ENABLE_TENSORFLOW END
216+
197217
/// Lookup table for SIL Global Variables.
198218
llvm::StringMap<SILGlobalVariable *> GlobalVariableMap;
199219

@@ -446,6 +466,27 @@ class SILModule {
446466
return {defaultWitnessTables.begin(), defaultWitnessTables.end()};
447467
}
448468

469+
// SWIFT_ENABLE_TENSORFLOW
470+
using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator;
471+
using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator;
472+
DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; }
473+
const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; }
474+
differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); }
475+
differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); }
476+
differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); }
477+
differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); }
478+
iterator_range<differentiability_witness_iterator>
479+
getDifferentiabilityWitnesses() {
480+
return {differentiabilityWitnesses.begin(),
481+
differentiabilityWitnesses.end()};
482+
}
483+
iterator_range<differentiability_witness_const_iterator>
484+
getDifferentiabilityWitnesses() const {
485+
return {differentiabilityWitnesses.begin(),
486+
differentiabilityWitnesses.end()};
487+
}
488+
// SWIFT_ENABLE_TENSORFLOW END
489+
449490
using sil_global_iterator = GlobalListType::iterator;
450491
using sil_global_const_iterator = GlobalListType::const_iterator;
451492
GlobalListType &getSILGlobalList() { return silGlobals; }

0 commit comments

Comments
 (0)