Skip to content

Commit f240ed2

Browse files
committed
Change AutoDiffConfig to a POD.
The contents of `AutoDiffConfig` are all uniqued, so uniquing the product does not make sense.
1 parent bb93af3 commit f240ed2

File tree

5 files changed

+57
-76
lines changed

5 files changed

+57
-76
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -212,35 +212,10 @@ struct AutoDiffDerivativeFunctionKind {
212212
/// - Parameter indices.
213213
/// - Result indices.
214214
/// - Derivative generic signature (optional).
215-
// TODO(TF-893): Use `AutoDiffConfig` in `AutoDiffDerivativeFunctionIdentifier`
216-
// to avoid duplication.
217-
class AutoDiffConfig : public llvm::FoldingSetNode {
218-
IndexSubset *const parameterIndices;
219-
IndexSubset *const resultIndices;
215+
struct AutoDiffConfig {
216+
IndexSubset *parameterIndices;
217+
IndexSubset *resultIndices;
220218
GenericSignature *derivativeGenericSignature;
221-
222-
AutoDiffConfig(IndexSubset *parameterIndices, IndexSubset *resultIndices,
223-
GenericSignature *derivativeGenericSignature)
224-
: parameterIndices(parameterIndices), resultIndices(resultIndices),
225-
derivativeGenericSignature(derivativeGenericSignature) {}
226-
227-
public:
228-
IndexSubset *getParameterIndices() const { return parameterIndices; }
229-
IndexSubset *getResultIndices() const { return resultIndices; }
230-
GenericSignature *getDerivativeGenericSignature() const {
231-
return derivativeGenericSignature;
232-
}
233-
234-
static AutoDiffConfig *get(IndexSubset *parameterIndices,
235-
IndexSubset *resultIndices,
236-
GenericSignature *derivativeGenericSignature,
237-
ASTContext &C);
238-
239-
void Profile(llvm::FoldingSetNodeID &ID) {
240-
ID.AddPointer(parameterIndices);
241-
ID.AddPointer(resultIndices);
242-
ID.AddPointer(derivativeGenericSignature);
243-
}
244219
};
245220

246221
/// In conjunction with the original function declaration, identifies an
@@ -253,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
253228
IndexSubset *const parameterIndices;
254229

255230
AutoDiffDerivativeFunctionIdentifier(
256-
AutoDiffDerivativeFunctionKind kind,
257-
IndexSubset *parameterIndices) :
231+
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
258232
kind(kind), parameterIndices(parameterIndices) {}
259233

260234
public:
@@ -276,7 +250,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
276250
/// The key type used for uniquing `SILDifferentiabilityWitness` in
277251
/// `SILModule`: original function name, parameter indices, result indices, and
278252
/// derivative generic signature.
279-
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig *>;
253+
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
280254

281255
/// Automatic differentiation utility namespace.
282256
namespace autodiff {
@@ -403,10 +377,44 @@ class VectorSpace {
403377

404378
namespace llvm {
405379

380+
using swift::AutoDiffConfig;
381+
using swift::AutoDiffDerivativeFunctionKind;
382+
using swift::GenericSignature;
383+
using swift::IndexSubset;
406384
using swift::SILAutoDiffIndices;
407385

408386
template<typename T> struct DenseMapInfo;
409387

388+
template<> struct DenseMapInfo<AutoDiffConfig> {
389+
static AutoDiffConfig getEmptyKey() {
390+
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
391+
return {static_cast<IndexSubset *>(ptr),
392+
static_cast<IndexSubset *>(ptr),
393+
static_cast<GenericSignature *>(ptr)};
394+
}
395+
396+
static AutoDiffConfig getTombstoneKey() {
397+
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
398+
return {static_cast<IndexSubset *>(ptr),
399+
static_cast<IndexSubset *>(ptr),
400+
static_cast<GenericSignature *>(ptr)};
401+
}
402+
403+
static unsigned getHashValue(const AutoDiffConfig &Val) {
404+
unsigned combinedHash = hash_combine(
405+
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
406+
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
407+
DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
408+
return combinedHash;
409+
}
410+
411+
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
412+
return LHS.parameterIndices == RHS.parameterIndices &&
413+
LHS.resultIndices == RHS.resultIndices &&
414+
LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
415+
}
416+
};
417+
410418
template<> struct DenseMapInfo<SILAutoDiffIndices> {
411419
static SILAutoDiffIndices getEmptyKey() {
412420
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ class SILDifferentiabilityWitness
4747
SILLinkage linkage;
4848
/// The original function.
4949
SILFunction *originalFunction;
50-
/// The autodiff configuration: parameter indices, result indices, and
51-
/// derivative generic signature (optional).
52-
AutoDiffConfig *autoDiffConfig;
50+
/// The parameter indices.
51+
IndexSubset *parameterIndices;
52+
/// The result indices.
53+
IndexSubset *resultIndices;
54+
/// The derivative generic signature (optional).
55+
GenericSignature *derivativeGenericSignature;
5356
/// The JVP (Jacobian-vector products) derivative function.
5457
SILFunction *jvp;
5558
/// The VJP (vector-Jacobian products) derivative function.
@@ -71,9 +74,9 @@ class SILDifferentiabilityWitness
7174
SILFunction *jvp, SILFunction *vjp,
7275
bool isSerialized)
7376
: module(module), linkage(linkage), originalFunction(originalFunction),
74-
autoDiffConfig(getAutoDiffConfig(
75-
module, parameterIndices, resultIndices, derivativeGenSig)),
76-
jvp(jvp), vjp(vjp), serialized(isSerialized) {}
77+
parameterIndices(parameterIndices), resultIndices(resultIndices),
78+
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
79+
serialized(isSerialized) {}
7780

7881
public:
7982
static SILDifferentiabilityWitness *create(
@@ -87,13 +90,13 @@ class SILDifferentiabilityWitness
8790
SILLinkage getLinkage() const { return linkage; }
8891
SILFunction *getOriginalFunction() const { return originalFunction; }
8992
IndexSubset *getParameterIndices() const {
90-
return autoDiffConfig->getParameterIndices();
93+
return parameterIndices;
9194
}
9295
IndexSubset *getResultIndices() const {
93-
return autoDiffConfig->getResultIndices();
96+
return resultIndices;
9497
}
9598
GenericSignature *getDerivativeGenericSignature() const {
96-
return autoDiffConfig->getDerivativeGenericSignature();
99+
return derivativeGenericSignature;
97100
}
98101
SILFunction *getJVP() const { return jvp; }
99102
SILFunction *getVJP() const { return vjp; }

lib/AST/ASTContext.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -449,9 +449,6 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
449449
/// For uniquifying `IndexSubset` allocations.
450450
llvm::FoldingSet<IndexSubset> IndexSubsets;
451451

452-
/// For uniquifying `AutoDiffConfig` allocations.
453-
llvm::FoldingSet<AutoDiffConfig> AutoDiffConfigs;
454-
455452
/// For uniquifying `AutoDiffDerivativeFunctionIdentifier` allocations.
456453
llvm::FoldingSet<AutoDiffDerivativeFunctionIdentifier>
457454
AutoDiffDerivativeFunctionIdentifiers;
@@ -4832,27 +4829,6 @@ IndexSubset::get(ASTContext &ctx, const SmallBitVector &indices) {
48324829
return newNode;
48334830
}
48344831

4835-
AutoDiffConfig *AutoDiffConfig::get(
4836-
IndexSubset *parameterIndices, IndexSubset *resultIndices,
4837-
GenericSignature *derivativeGenericSignature, ASTContext &C) {
4838-
assert(parameterIndices);
4839-
assert(resultIndices);
4840-
auto &foldingSet = C.getImpl().AutoDiffConfigs;
4841-
llvm::FoldingSetNodeID id;
4842-
id.AddPointer(parameterIndices);
4843-
id.AddPointer(resultIndices);
4844-
id.AddPointer(derivativeGenericSignature);
4845-
void *insertPos;
4846-
auto *existing = foldingSet.FindNodeOrInsertPos(id, insertPos);
4847-
if (existing)
4848-
return existing;
4849-
void *buf = C.Allocate(sizeof(AutoDiffConfig), alignof(AutoDiffConfig));
4850-
auto *newNode = new (buf) AutoDiffConfig(
4851-
parameterIndices, resultIndices, derivativeGenericSignature);
4852-
foldingSet.InsertNode(newNode, insertPos);
4853-
return newNode;
4854-
}
4855-
48564832
AutoDiffDerivativeFunctionIdentifier *
48574833
AutoDiffDerivativeFunctionIdentifier::get(
48584834
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices,

lib/AST/ASTMangler.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,9 @@ std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
432432
beginManglingWithoutPrefix();
433433

434434
auto originalName = key.first;
435-
auto *autoDiffConfig = key.second;
436-
auto *parameterIndices = autoDiffConfig->getParameterIndices();
437-
auto *resultIndices = autoDiffConfig->getResultIndices();
438-
auto *derivativeGenericSignature = autoDiffConfig->getDerivativeGenericSignature();
435+
auto *parameterIndices = key.second.parameterIndices;
436+
auto *resultIndices = key.second.resultIndices;
437+
auto *derivativeGenericSignature = key.second.derivativeGenericSignature;
439438

440439
Buffer << "AD__" << originalName << '_';
441440
Buffer << "P" << parameterIndices->getString();

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,8 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
3535
return diffWitness;
3636
}
3737

38-
AutoDiffConfig *SILDifferentiabilityWitness::getAutoDiffConfig(
39-
SILModule &module, IndexSubset *parameterIndices,
40-
IndexSubset *resultIndices, GenericSignature *derivativeGenSig) {
41-
return AutoDiffConfig::get(parameterIndices, resultIndices, derivativeGenSig,
42-
module.getASTContext());
43-
}
44-
4538
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
46-
return std::make_pair(originalFunction->getName(), autoDiffConfig);
39+
AutoDiffConfig config{parameterIndices, resultIndices,
40+
derivativeGenericSignature};
41+
return std::make_pair(originalFunction->getName(), config);
4742
}

0 commit comments

Comments
 (0)