Skip to content

Commit f5b40d6

Browse files
authored
[AutoDiff upstream] Add SIL derivative function type caching. (#29953)
Upstream #29590: cache `SILFunctionType::getAutoDiffDerivativeFunctionType` results.
1 parent 88ca382 commit f5b40d6

File tree

3 files changed

+77
-5
lines changed

3 files changed

+77
-5
lines changed

include/swift/AST/ASTContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ namespace swift {
114114
class VarDecl;
115115
class UnifiedStatsReporter;
116116
class IndexSubset;
117+
struct SILAutoDiffDerivativeFunctionKey;
117118

118119
enum class KnownProtocolKind : uint8_t;
119120

@@ -288,6 +289,10 @@ class ASTContext final {
288289
/// Cached mapping from types to their associated tangent spaces.
289290
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
290291

292+
/// A cache of derivative function types per configuration.
293+
llvm::DenseMap<SILAutoDiffDerivativeFunctionKey, CanSILFunctionType>
294+
SILAutoDiffDerivativeFunctions;
295+
291296
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
292297
/// diagnose duplicate `@differentiable` attributes for the same key.
293298
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>

include/swift/AST/AutoDiff.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ struct AutoDiffConfig {
113113
SWIFT_DEBUG_DUMP;
114114
};
115115

116+
/// Key for caching SIL derivative function types.
117+
struct SILAutoDiffDerivativeFunctionKey {
118+
SILFunctionType *originalType;
119+
IndexSubset *parameterIndices;
120+
IndexSubset *resultIndices;
121+
AutoDiffDerivativeFunctionKind kind;
122+
CanGenericSignature derivativeFnGenSig;
123+
bool isReabstractionThunk;
124+
};
125+
116126
class ParsedAutoDiffParameter {
117127
public:
118128
enum class Kind { Named, Ordered, Self };
@@ -281,8 +291,11 @@ namespace llvm {
281291

282292
using swift::AutoDiffConfig;
283293
using swift::AutoDiffDerivativeFunctionKind;
294+
using swift::CanGenericSignature;
284295
using swift::GenericSignature;
285296
using swift::IndexSubset;
297+
using swift::SILAutoDiffDerivativeFunctionKey;
298+
using swift::SILFunctionType;
286299

287300
template <typename T> struct DenseMapInfo;
288301

@@ -354,6 +367,50 @@ template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
354367
}
355368
};
356369

370+
template <> struct DenseMapInfo<SILAutoDiffDerivativeFunctionKey> {
371+
static bool isEqual(const SILAutoDiffDerivativeFunctionKey lhs,
372+
const SILAutoDiffDerivativeFunctionKey rhs) {
373+
return lhs.originalType == rhs.originalType &&
374+
lhs.parameterIndices == rhs.parameterIndices &&
375+
lhs.resultIndices == rhs.resultIndices &&
376+
lhs.kind.rawValue == rhs.kind.rawValue &&
377+
lhs.derivativeFnGenSig == rhs.derivativeFnGenSig &&
378+
lhs.isReabstractionThunk == rhs.isReabstractionThunk;
379+
}
380+
381+
static inline SILAutoDiffDerivativeFunctionKey getEmptyKey() {
382+
return {DenseMapInfo<SILFunctionType *>::getEmptyKey(),
383+
DenseMapInfo<IndexSubset *>::getEmptyKey(),
384+
DenseMapInfo<IndexSubset *>::getEmptyKey(),
385+
AutoDiffDerivativeFunctionKind::innerty(
386+
DenseMapInfo<unsigned>::getEmptyKey()),
387+
CanGenericSignature(DenseMapInfo<GenericSignature>::getEmptyKey()),
388+
(bool)DenseMapInfo<unsigned>::getEmptyKey()};
389+
}
390+
391+
static inline SILAutoDiffDerivativeFunctionKey getTombstoneKey() {
392+
return {
393+
DenseMapInfo<SILFunctionType *>::getTombstoneKey(),
394+
DenseMapInfo<IndexSubset *>::getTombstoneKey(),
395+
DenseMapInfo<IndexSubset *>::getTombstoneKey(),
396+
AutoDiffDerivativeFunctionKind::innerty(
397+
DenseMapInfo<unsigned>::getTombstoneKey()),
398+
CanGenericSignature(DenseMapInfo<GenericSignature>::getTombstoneKey()),
399+
(bool)DenseMapInfo<unsigned>::getTombstoneKey()};
400+
}
401+
402+
static unsigned getHashValue(const SILAutoDiffDerivativeFunctionKey &Val) {
403+
return hash_combine(
404+
DenseMapInfo<SILFunctionType *>::getHashValue(Val.originalType),
405+
DenseMapInfo<IndexSubset *>::getHashValue(Val.parameterIndices),
406+
DenseMapInfo<IndexSubset *>::getHashValue(Val.resultIndices),
407+
DenseMapInfo<unsigned>::getHashValue((unsigned)Val.kind.rawValue),
408+
DenseMapInfo<GenericSignature>::getHashValue(Val.derivativeFnGenSig),
409+
DenseMapInfo<unsigned>::getHashValue(
410+
(unsigned)Val.isReabstractionThunk));
411+
}
412+
};
413+
357414
} // end namespace llvm
358415

359416
#endif // SWIFT_AST_AUTODIFF_H

lib/SIL/SILFunctionType.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,15 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
247247
LookupConformanceFn lookupConformance,
248248
CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) {
249249
auto &ctx = getASTContext();
250+
auto resultIndices = IndexSubset::get(ctx, getNumResults(), {resultIndex});
251+
SILAutoDiffDerivativeFunctionKey key{
252+
this, parameterIndices, resultIndices,
253+
kind, derivativeFnGenSig, isReabstractionThunk};
254+
auto insertion =
255+
ctx.SILAutoDiffDerivativeFunctions.try_emplace(key, CanSILFunctionType());
256+
auto &cachedResult = insertion.first->getSecond();
257+
if (!insertion.second)
258+
return cachedResult;
250259

251260
// Returns true if `index` is a differentiability parameter index.
252261
auto isDiffParamIndex = [&](unsigned index) -> bool {
@@ -396,11 +405,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
396405
auto extInfo = getExtInfo();
397406
if (getRepresentation() == SILFunctionTypeRepresentation::CFunctionPointer)
398407
extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
399-
return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(),
400-
getCalleeConvention(), newParameters, getYields(),
401-
newResults, getOptionalErrorResult(),
402-
getSubstitutions(), isGenericSignatureImplied(),
403-
ctx, getWitnessMethodConformanceOrInvalid());
408+
cachedResult = SILFunctionType::get(
409+
canGenSig, extInfo, getCoroutineKind(), getCalleeConvention(),
410+
newParameters, getYields(), newResults, getOptionalErrorResult(),
411+
getSubstitutions(), isGenericSignatureImplied(), ctx,
412+
getWitnessMethodConformanceOrInvalid());
413+
return cachedResult;
404414
}
405415

406416
CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(

0 commit comments

Comments
 (0)