Skip to content

Commit bb93af3

Browse files
committed
Add AutoDiffConfig and use in SILDifferentiabilityWitnessKey.
1 parent aea64d3 commit bb93af3

File tree

7 files changed

+110
-35
lines changed

7 files changed

+110
-35
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,41 @@ struct AutoDiffDerivativeFunctionKind {
208208
}
209209
};
210210

211+
/// Identifies an autodiff derivative function configuration:
212+
/// - Parameter indices.
213+
/// - Result indices.
214+
/// - Derivative generic signature (optional).
215+
// TODO(TF-893): Use `AutoDiffConfig` in `AutoDiffDerivativeFunctionIdentifier`
216+
// to avoid duplication.
217+
class AutoDiffConfig : public llvm::FoldingSetNode {
218+
IndexSubset *const parameterIndices;
219+
IndexSubset *const resultIndices;
220+
GenericSignature *derivativeGenericSignature;
221+
222+
AutoDiffConfig(IndexSubset *parameterIndices, IndexSubset *resultIndices,
223+
GenericSignature *derivativeGenericSignature)
224+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
225+
derivativeGenericSignature(derivativeGenericSignature) {}
226+
227+
public:
228+
IndexSubset *getParameterIndices() const { return parameterIndices; }
229+
IndexSubset *getResultIndices() const { return resultIndices; }
230+
GenericSignature *getDerivativeGenericSignature() const {
231+
return derivativeGenericSignature;
232+
}
233+
234+
static AutoDiffConfig *get(IndexSubset *parameterIndices,
235+
IndexSubset *resultIndices,
236+
GenericSignature *derivativeGenericSignature,
237+
ASTContext &C);
238+
239+
void Profile(llvm::FoldingSetNodeID &ID) {
240+
ID.AddPointer(parameterIndices);
241+
ID.AddPointer(resultIndices);
242+
ID.AddPointer(derivativeGenericSignature);
243+
}
244+
};
245+
211246
/// In conjunction with the original function declaration, identifies an
212247
/// autodiff derivative function.
213248
///
@@ -241,9 +276,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
241276
/// The key type used for uniquing `SILDifferentiabilityWitness` in
242277
/// `SILModule`: original function name, parameter indices, result indices, and
243278
/// derivative generic signature.
244-
// TODO(TF-893): Unify with `AutoDiffDerivativeFunctionIdentifier`.
245-
using SILDifferentiabilityWitnessKey =
246-
std::tuple<StringRef, IndexSubset *, IndexSubset *, GenericSignature *>;
279+
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig *>;
247280

248281
/// Automatic differentiation utility namespace.
249282
namespace autodiff {

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "swift/AST/AutoDiff.h"
3030
#include "swift/AST/GenericSignature.h"
3131
#include "swift/SIL/SILAllocated.h"
32-
#include "swift/SIL/SILInstruction.h"
3332
#include "llvm/ADT/ilist_node.h"
3433
#include "llvm/ADT/ilist.h"
3534

@@ -48,12 +47,9 @@ class SILDifferentiabilityWitness
4847
SILLinkage linkage;
4948
/// The original function.
5049
SILFunction *originalFunction;
51-
/// The parameter indices.
52-
IndexSubset *parameterIndices;
53-
/// The result indices.
54-
IndexSubset *resultIndices;
55-
/// The derivative generic signature (optional).
56-
GenericSignature *derivativeGenericSignature;
50+
/// The autodiff configuration: parameter indices, result indices, and
51+
/// derivative generic signature (optional).
52+
AutoDiffConfig *autoDiffConfig;
5753
/// The JVP (Jacobian-vector products) derivative function.
5854
SILFunction *jvp;
5955
/// The VJP (vector-Jacobian products) derivative function.
@@ -62,6 +58,11 @@ class SILDifferentiabilityWitness
6258
/// devirtualization from another module.
6359
bool serialized;
6460

61+
static AutoDiffConfig *
62+
getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices,
63+
IndexSubset *resultIndices,
64+
GenericSignature *derivativeGenSig);
65+
6566
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
6667
SILFunction *originalFunction,
6768
IndexSubset *parameterIndices,
@@ -70,9 +71,9 @@ class SILDifferentiabilityWitness
7071
SILFunction *jvp, SILFunction *vjp,
7172
bool isSerialized)
7273
: module(module), linkage(linkage), originalFunction(originalFunction),
73-
parameterIndices(parameterIndices), resultIndices(resultIndices),
74-
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
75-
serialized(isSerialized) {}
74+
autoDiffConfig(getAutoDiffConfig(
75+
module, parameterIndices, resultIndices, derivativeGenSig)),
76+
jvp(jvp), vjp(vjp), serialized(isSerialized) {}
7677

7778
public:
7879
static SILDifferentiabilityWitness *create(
@@ -86,16 +87,29 @@ class SILDifferentiabilityWitness
8687
SILLinkage getLinkage() const { return linkage; }
8788
SILFunction *getOriginalFunction() const { return originalFunction; }
8889
IndexSubset *getParameterIndices() const {
89-
return parameterIndices;
90+
return autoDiffConfig->getParameterIndices();
9091
}
9192
IndexSubset *getResultIndices() const {
92-
return resultIndices;
93+
return autoDiffConfig->getResultIndices();
9394
}
9495
GenericSignature *getDerivativeGenericSignature() const {
95-
return derivativeGenericSignature;
96+
return autoDiffConfig->getDerivativeGenericSignature();
9697
}
9798
SILFunction *getJVP() const { return jvp; }
9899
SILFunction *getVJP() const { return vjp; }
100+
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
101+
switch (kind) {
102+
case AutoDiffDerivativeFunctionKind::JVP: return jvp;
103+
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
104+
}
105+
}
106+
void setDerivative(AutoDiffDerivativeFunctionKind kind,
107+
SILFunction *derivative) {
108+
switch (kind) {
109+
case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
110+
case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
111+
}
112+
}
99113
bool isSerialized() const { return serialized; }
100114

101115
/// Verify that the differentiability witness is well-formed.

lib/AST/ASTContext.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,9 +449,13 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
449449
/// For uniquifying `IndexSubset` allocations.
450450
llvm::FoldingSet<IndexSubset> IndexSubsets;
451451

452+
/// For uniquifying `AutoDiffConfig` allocations.
453+
llvm::FoldingSet<AutoDiffConfig> AutoDiffConfigs;
454+
452455
/// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations.
453456
llvm::FoldingSet<AutoDiffDerivativeFunctionIdentifier>
454457
AutoDiffDerivativeFunctionIdentifiers;
458+
// SWIFT_ENABLE_TENSORFLOW END
455459

456460
/// A cache of information about whether particular nominal types
457461
/// are representable in a foreign language.
@@ -4828,6 +4832,27 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
48284832
return newNode;
48294833
}
48304834

4835+
AutoDiffConfig *AutoDiffConfig::get(
4836+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
4837+
GenericSignature *derivativeGenericSignature, ASTContext &C) {
4838+
assert(parameterIndices);
4839+
assert(resultIndices);
4840+
auto &foldingSet = C.getImpl().AutoDiffConfigs;
4841+
llvm::FoldingSetNodeID id;
4842+
id.AddPointer(parameterIndices);
4843+
id.AddPointer(resultIndices);
4844+
id.AddPointer(derivativeGenericSignature);
4845+
void *insertPos;
4846+
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4847+
if (existing)
4848+
return existing;
4849+
void *buf = C.Allocate(sizeof(AutoDiffConfig), alignof(AutoDiffConfig));
4850+
auto *newNode = new (buf) AutoDiffConfig(
4851+
parameterIndices, resultIndices, derivativeGenericSignature);
4852+
foldingSet.InsertNode(newNode, insertPos);
4853+
return newNode;
4854+
}
4855+
48314856
AutoDiffDerivativeFunctionIdentifier *
48324857
AutoDiffDerivativeFunctionIdentifier::get(
48334858
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,

lib/AST/ASTMangler.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,11 @@ std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
431431
// TODO(TF-20): Make the mangling scheme robust.
432432
beginManglingWithoutPrefix();
433433

434-
auto originalName = std::get<0>(key);
435-
auto *parameterIndices = std::get<1>(key);
436-
auto *resultIndices = std::get<2>(key);
437-
auto *derivativeGenericSignature = std::get<3>(key);
434+
auto originalName = key.first;
435+
auto *autoDiffConfig = key.second;
436+
auto *parameterIndices = autoDiffConfig->getParameterIndices();
437+
auto *resultIndices = autoDiffConfig->getResultIndices();
438+
auto *derivativeGenericSignature = autoDiffConfig->getDerivativeGenericSignature();
438439

439440
Buffer << "AD__" << originalName << '_';
440441
Buffer << "P" << parameterIndices->getString();

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
3535
return diffWitness;
3636
}
3737

38+
AutoDiffConfig *SILDifferentiabilityWitness::getAutoDiffConfig(
39+
SILModule &module, IndexSubset *parameterIndices,
40+
IndexSubset *resultIndices, GenericSignature *derivativeGenSig) {
41+
return AutoDiffConfig::get(parameterIndices, resultIndices, derivativeGenSig,
42+
module.getASTContext());
43+
}
44+
3845
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
39-
return std::make_tuple(originalFunction->getName(), parameterIndices,
40-
resultIndices, derivativeGenericSignature);
46+
return std::make_pair(originalFunction->getName(), autoDiffConfig);
4147
}

lib/SIL/SILPrinter.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,31 +3068,27 @@ void SILDifferentiabilityWitness::print(
30683068
printLinkage(OS, linkage, ForDefinition);
30693069
// [parameters 0 1 ...]
30703070
OS << "[parameters ";
3071-
interleave(parameterIndices->getIndices(),
3071+
interleave(getParameterIndices()->getIndices(),
30723072
[&](unsigned index) { OS << index; },
30733073
[&] { OS << " "; });
30743074
// [results 0 1 ...]
30753075
OS << "] [results ";
3076-
interleave(resultIndices->getIndices(),
3076+
interleave(getResultIndices()->getIndices(),
30773077
[&](unsigned index) { OS << index; },
30783078
[&] { OS << " "; });
30793079
OS << ']';
30803080
// [where ...]
3081-
if (derivativeGenericSignature) {
3082-
// NOTE: This needs to be changed if there is no utility for parsing
3083-
// generic signatures. Idea: we could instead print the type of the original
3084-
// function substituted into this generic signature.
3081+
if (auto *derivativeGenSig = getDerivativeGenericSignature()) {
30853082
ArrayRef<Requirement> requirements;
30863083
SmallVector<Requirement, 4> requirementsScratch;
30873084
auto *origGenEnv = originalFunction->getGenericEnvironment();
3088-
if (derivativeGenericSignature) {
3085+
if (derivativeGenSig) {
30893086
if (origGenEnv) {
3090-
requirementsScratch =
3091-
derivativeGenericSignature->requirementsNotSatisfiedBy(
3092-
origGenEnv->getGenericSignature());
3087+
requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy(
3088+
origGenEnv->getGenericSignature());
30933089
requirements = requirementsScratch;
30943090
} else {
3095-
requirements = derivativeGenericSignature->getRequirements();
3091+
requirements = derivativeGenSig->getRequirements();
30963092
}
30973093
}
30983094
if (!requirements.empty()) {

lib/SIL/SILVerifier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5365,7 +5365,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
53655365
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
53665366
// to accept result indices.
53675367
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
5368-
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
5368+
getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
53695369
AutoDiffDerivativeFunctionKind::JVP, M.Types,
53705370
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
53715371
SILVerifier(*jvp).requireSameType(
@@ -5377,7 +5377,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
53775377
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
53785378
// to result indices.
53795379
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
5380-
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
5380+
getParameterIndices(), /*resultIndex*/ *getResultIndices()->begin(),
53815381
AutoDiffDerivativeFunctionKind::VJP, M.Types,
53825382
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
53835383
SILVerifier(*vjp).requireSameType(

0 commit comments

Comments
 (0)