@@ -212,35 +212,10 @@ struct AutoDiffDerivativeFunctionKind {
212
212
// / - Parameter indices.
213
213
// / - Result indices.
214
214
// / - Derivative generic signature (optional).
215
- // TODO(TF-893): Use `AutoDiffConfig` in `AutoDiffDerivativeFunctionIdentifier`
216
- // to avoid duplication.
217
- class AutoDiffConfig : public llvm ::FoldingSetNode {
218
- IndexSubset *const parameterIndices;
219
- IndexSubset *const resultIndices;
215
+ struct AutoDiffConfig {
216
+ IndexSubset *parameterIndices;
217
+ IndexSubset *resultIndices;
220
218
GenericSignature *derivativeGenericSignature;
221
-
222
- AutoDiffConfig (IndexSubset *parameterIndices, IndexSubset *resultIndices,
223
- GenericSignature *derivativeGenericSignature)
224
- : parameterIndices(parameterIndices), resultIndices(resultIndices),
225
- derivativeGenericSignature (derivativeGenericSignature) {}
226
-
227
- public:
228
- IndexSubset *getParameterIndices () const { return parameterIndices; }
229
- IndexSubset *getResultIndices () const { return resultIndices; }
230
- GenericSignature *getDerivativeGenericSignature () const {
231
- return derivativeGenericSignature;
232
- }
233
-
234
- static AutoDiffConfig *get (IndexSubset *parameterIndices,
235
- IndexSubset *resultIndices,
236
- GenericSignature *derivativeGenericSignature,
237
- ASTContext &C);
238
-
239
- void Profile (llvm::FoldingSetNodeID &ID) {
240
- ID.AddPointer (parameterIndices);
241
- ID.AddPointer (resultIndices);
242
- ID.AddPointer (derivativeGenericSignature);
243
- }
244
219
};
245
220
246
221
// / In conjunction with the original function declaration, identifies an
@@ -253,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
253
228
IndexSubset *const parameterIndices;
254
229
255
230
AutoDiffDerivativeFunctionIdentifier (
256
- AutoDiffDerivativeFunctionKind kind,
257
- IndexSubset *parameterIndices) :
231
+ AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
258
232
kind (kind), parameterIndices(parameterIndices) {}
259
233
260
234
public:
@@ -276,7 +250,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
276
250
// / The key type used for uniquing `SILDifferentiabilityWitness` in
277
251
// / `SILModule`: original function name, parameter indices, result indices, and
278
252
// / derivative generic signature.
279
- using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig * >;
253
+ using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
280
254
281
255
// / Automatic differentiation utility namespace.
282
256
namespace autodiff {
@@ -403,10 +377,44 @@ class VectorSpace {
403
377
404
378
namespace llvm {
405
379
380
+ using swift::AutoDiffConfig;
381
+ using swift::AutoDiffDerivativeFunctionKind;
382
+ using swift::GenericSignature;
383
+ using swift::IndexSubset;
406
384
using swift::SILAutoDiffIndices;
407
385
408
386
template <typename T> struct DenseMapInfo ;
409
387
388
+ template <> struct DenseMapInfo <AutoDiffConfig> {
389
+ static AutoDiffConfig getEmptyKey () {
390
+ auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey ();
391
+ return {static_cast <IndexSubset *>(ptr),
392
+ static_cast <IndexSubset *>(ptr),
393
+ static_cast <GenericSignature *>(ptr)};
394
+ }
395
+
396
+ static AutoDiffConfig getTombstoneKey () {
397
+ auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey ();
398
+ return {static_cast <IndexSubset *>(ptr),
399
+ static_cast <IndexSubset *>(ptr),
400
+ static_cast <GenericSignature *>(ptr)};
401
+ }
402
+
403
+ static unsigned getHashValue (const AutoDiffConfig &Val) {
404
+ unsigned combinedHash = hash_combine (
405
+ ~1U , DenseMapInfo<void *>::getHashValue (Val.parameterIndices ),
406
+ DenseMapInfo<void *>::getHashValue (Val.resultIndices ),
407
+ DenseMapInfo<void *>::getHashValue (Val.derivativeGenericSignature ));
408
+ return combinedHash;
409
+ }
410
+
411
+ static bool isEqual (const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
412
+ return LHS.parameterIndices == RHS.parameterIndices &&
413
+ LHS.resultIndices == RHS.resultIndices &&
414
+ LHS.derivativeGenericSignature == RHS.derivativeGenericSignature ;
415
+ }
416
+ };
417
+
410
418
template <> struct DenseMapInfo <SILAutoDiffIndices> {
411
419
static SILAutoDiffIndices getEmptyKey () {
412
420
return { DenseMapInfo<unsigned >::getEmptyKey (), nullptr };
0 commit comments