Skip to content

Commit fbfb67a

Browse files
authored
[AutoDiff upstream] Define AutoDiffConfig. (#29099)
Define `AutoDiffConfig`: the configuration for a derivative function: - Parameter indices. - Result indices. - Derivative generic signature (optional). Progress towards TF-828: upstream `@differentiable` attribute type-checking.
1 parent 90d94c8 commit fbfb67a

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cstdint>
2121

22+
#include "swift/AST/GenericSignature.h"
2223
#include "swift/AST/Identifier.h"
2324
#include "swift/AST/IndexSubset.h"
2425
#include "swift/AST/Type.h"
@@ -70,6 +71,25 @@ struct AutoDiffDerivativeFunctionKind {
7071
}
7172
};
7273

74+
/// Identifies an autodiff derivative function configuration:
75+
/// - Parameter indices.
76+
/// - Result indices.
77+
/// - Derivative generic signature (optional).
78+
struct AutoDiffConfig {
79+
IndexSubset *parameterIndices;
80+
IndexSubset *resultIndices;
81+
GenericSignature derivativeGenericSignature;
82+
83+
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
84+
IndexSubset *resultIndices,
85+
GenericSignature derivativeGenericSignature)
86+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
87+
derivativeGenericSignature(derivativeGenericSignature) {}
88+
89+
void print(llvm::raw_ostream &s = llvm::outs()) const;
90+
SWIFT_DEBUG_DUMP;
91+
};
92+
7393
class ParsedAutoDiffParameter {
7494
public:
7595
enum class Kind { Named, Ordered, Self };
@@ -148,10 +168,59 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
148168

149169
namespace llvm {
150170

171+
using swift::AutoDiffConfig;
151172
using swift::AutoDiffDerivativeFunctionKind;
173+
using swift::GenericSignature;
174+
using swift::IndexSubset;
152175

153176
template <typename T> struct DenseMapInfo;
154177

178+
template <> struct DenseMapInfo<AutoDiffConfig> {
179+
static AutoDiffConfig getEmptyKey() {
180+
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
181+
// The `derivativeGenericSignature` component must be `nullptr` so that
182+
// `getHashValue` and `isEqual` do not try to call
183+
// `GenericSignatureImpl::getCanonicalSignature()` on an invalid pointer.
184+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
185+
nullptr};
186+
}
187+
188+
static AutoDiffConfig getTombstoneKey() {
189+
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
190+
// The `derivativeGenericSignature` component must be `nullptr` so that
191+
// `getHashValue` and `isEqual` do not try to call
192+
// `GenericSignatureImpl::getCanonicalSignature()` on an invalid pointer.
193+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
194+
nullptr};
195+
}
196+
197+
static unsigned getHashValue(const AutoDiffConfig &Val) {
198+
auto canGenSig =
199+
Val.derivativeGenericSignature
200+
? Val.derivativeGenericSignature->getCanonicalSignature()
201+
: nullptr;
202+
unsigned combinedHash = hash_combine(
203+
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
204+
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
205+
DenseMapInfo<GenericSignature>::getHashValue(canGenSig));
206+
return combinedHash;
207+
}
208+
209+
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
210+
auto lhsCanGenSig =
211+
LHS.derivativeGenericSignature
212+
? LHS.derivativeGenericSignature->getCanonicalSignature()
213+
: nullptr;
214+
auto rhsCanGenSig =
215+
RHS.derivativeGenericSignature
216+
? RHS.derivativeGenericSignature->getCanonicalSignature()
217+
: nullptr;
218+
return LHS.parameterIndices == RHS.parameterIndices &&
219+
LHS.resultIndices == RHS.resultIndices &&
220+
DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig);
221+
}
222+
};
223+
155224
template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
156225
static AutoDiffDerivativeFunctionKind getEmptyKey() {
157226
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(

0 commit comments

Comments
 (0)