|
19 | 19 |
|
20 | 20 | #include <cstdint>
|
21 | 21 |
|
| 22 | +#include "swift/AST/GenericSignature.h" |
22 | 23 | #include "swift/AST/Identifier.h"
|
23 | 24 | #include "swift/AST/IndexSubset.h"
|
24 | 25 | #include "swift/AST/Type.h"
|
@@ -70,6 +71,25 @@ struct AutoDiffDerivativeFunctionKind {
|
70 | 71 | }
|
71 | 72 | };
|
72 | 73 |
|
| 74 | +/// Identifies an autodiff derivative function configuration: |
| 75 | +/// - Parameter indices. |
| 76 | +/// - Result indices. |
| 77 | +/// - Derivative generic signature (optional). |
| 78 | +struct AutoDiffConfig { |
| 79 | + IndexSubset *parameterIndices; |
| 80 | + IndexSubset *resultIndices; |
| 81 | + GenericSignature derivativeGenericSignature; |
| 82 | + |
| 83 | + /*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices, |
| 84 | + IndexSubset *resultIndices, |
| 85 | + GenericSignature derivativeGenericSignature) |
| 86 | + : parameterIndices(parameterIndices), resultIndices(resultIndices), |
| 87 | + derivativeGenericSignature(derivativeGenericSignature) {} |
| 88 | + |
| 89 | + void print(llvm::raw_ostream &s = llvm::outs()) const; |
| 90 | + SWIFT_DEBUG_DUMP; |
| 91 | +}; |
| 92 | + |
73 | 93 | class ParsedAutoDiffParameter {
|
74 | 94 | public:
|
75 | 95 | enum class Kind { Named, Ordered, Self };
|
@@ -148,10 +168,59 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
|
148 | 168 |
|
149 | 169 | namespace llvm {
|
150 | 170 |
|
| 171 | +using swift::AutoDiffConfig; |
151 | 172 | using swift::AutoDiffDerivativeFunctionKind;
|
| 173 | +using swift::GenericSignature; |
| 174 | +using swift::IndexSubset; |
152 | 175 |
|
153 | 176 | template <typename T> struct DenseMapInfo;
|
154 | 177 |
|
| 178 | +template <> struct DenseMapInfo<AutoDiffConfig> { |
| 179 | + static AutoDiffConfig getEmptyKey() { |
| 180 | + auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey(); |
| 181 | + // The `derivativeGenericSignature` component must be `nullptr` so that |
| 182 | + // `getHashValue` and `isEqual` do not try to call |
| 183 | + // `GenericSignatureImpl::getCanonicalSignature()` on an invalid pointer. |
| 184 | + return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr), |
| 185 | + nullptr}; |
| 186 | + } |
| 187 | + |
| 188 | + static AutoDiffConfig getTombstoneKey() { |
| 189 | + auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
| 190 | + // The `derivativeGenericSignature` component must be `nullptr` so that |
| 191 | + // `getHashValue` and `isEqual` do not try to call |
| 192 | + // `GenericSignatureImpl::getCanonicalSignature()` on an invalid pointer. |
| 193 | + return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr), |
| 194 | + nullptr}; |
| 195 | + } |
| 196 | + |
| 197 | + static unsigned getHashValue(const AutoDiffConfig &Val) { |
| 198 | + auto canGenSig = |
| 199 | + Val.derivativeGenericSignature |
| 200 | + ? Val.derivativeGenericSignature->getCanonicalSignature() |
| 201 | + : nullptr; |
| 202 | + unsigned combinedHash = hash_combine( |
| 203 | + ~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices), |
| 204 | + DenseMapInfo<void *>::getHashValue(Val.resultIndices), |
| 205 | + DenseMapInfo<GenericSignature>::getHashValue(canGenSig)); |
| 206 | + return combinedHash; |
| 207 | + } |
| 208 | + |
| 209 | + static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) { |
| 210 | + auto lhsCanGenSig = |
| 211 | + LHS.derivativeGenericSignature |
| 212 | + ? LHS.derivativeGenericSignature->getCanonicalSignature() |
| 213 | + : nullptr; |
| 214 | + auto rhsCanGenSig = |
| 215 | + RHS.derivativeGenericSignature |
| 216 | + ? RHS.derivativeGenericSignature->getCanonicalSignature() |
| 217 | + : nullptr; |
| 218 | + return LHS.parameterIndices == RHS.parameterIndices && |
| 219 | + LHS.resultIndices == RHS.resultIndices && |
| 220 | + DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig); |
| 221 | + } |
| 222 | +}; |
| 223 | + |
155 | 224 | template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
|
156 | 225 | static AutoDiffDerivativeFunctionKind getEmptyKey() {
|
157 | 226 | return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
|
|
0 commit comments