Skip to content

Commit 486d206

Browse files
committed
[AutoDiff] Enhance performance of custom derivatives lookup
In swiftlang#58965, lookup for custom derivatives in non-primary source files was introduced. It required traversing all delayed parsed function bodies of a file if the file was compiled with differential programming enabled (even for functions with no `@derivative` attribute). This patch introduces `CustomDerivativesLookupRequest` to address the issue. Resolves swiftlang#60102
1 parent 3aed095 commit 486d206

File tree

5 files changed

+55
-23
lines changed

5 files changed

+55
-23
lines changed

include/swift/AST/Module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "swift/AST/Import.h"
2626
#include "swift/AST/LookupKinds.h"
2727
#include "swift/AST/Type.h"
28+
#include "swift/AST/TypeCheckRequests.h"
2829
#include "swift/Basic/Assertions.h"
2930
#include "swift/Basic/BasicSourceInfo.h"
3031
#include "swift/Basic/CXXStdlibKind.h"
@@ -240,6 +241,7 @@ class ModuleDecl
240241
: public DeclContext, public TypeDecl, public ASTAllocated<ModuleDecl> {
241242
friend class DirectOperatorLookupRequest;
242243
friend class DirectPrecedenceGroupLookupRequest;
244+
friend class CustomDerivativesRequest;
243245

244246
/// The ABI name of the module, if it differs from the module name.
245247
mutable Identifier ModuleABIName;

include/swift/AST/TypeCheckRequests.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5165,6 +5165,23 @@ class GenericTypeParamDeclGetValueTypeRequest
51655165
bool isCached() const { return true; }
51665166
};
51675167

5168+
class CustomDerivativesRequest
5169+
: public SimpleRequest<CustomDerivativesRequest,
5170+
evaluator::SideEffect(ModuleDecl *),
5171+
RequestFlags::Cached> {
5172+
public:
5173+
using SimpleRequest::SimpleRequest;
5174+
5175+
private:
5176+
friend SimpleRequest;
5177+
5178+
evaluator::SideEffect evaluate(Evaluator &evaluator,
5179+
ModuleDecl *module) const;
5180+
5181+
public:
5182+
bool isCached() const { return true; }
5183+
};
5184+
51685185
#define SWIFT_TYPEID_ZONE TypeChecker
51695186
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
51705187
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest,
605605
SWIFT_REQUEST(TypeChecker, IsUnsafeRequest,
606606
bool(Decl *),
607607
SeparatelyCached, NoLocationInfo)
608+
SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest,
609+
CustomDerivativesResult(ModuleDecl *),
610+
Cached, NoLocationInfo)
608611

609612
SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest,
610613
Type(GenericTypeParamDecl *), Cached, NoLocationInfo)

lib/AST/Module.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class swift::SourceLookupCache {
166166
ValueDeclMap TopLevelValues;
167167
ValueDeclMap ClassMembers;
168168
bool MemberCachePopulated = false;
169+
CustomDerivativesLookupResult CustomDerivatives;
169170
DeclName UniqueMacroNamePlaceholder;
170171

171172
template<typename T>
@@ -205,6 +206,9 @@ class swift::SourceLookupCache {
205206
/// guaranteed to be meaningful.
206207
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
207208

209+
// TODO: is it valid to return const reference from here?
210+
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
211+
208212
/// Look up an operator declaration.
209213
///
210214
/// \param name The operator name ("+", ">>", etc.)
@@ -275,6 +279,11 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
275279
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
276280
MayHaveAuxiliaryDecls.push_back(VD);
277281
}
282+
283+
if (!onlyOperators)
284+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
285+
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
286+
CustomDerivatives.push_back(AFD);
278287
}
279288
}
280289

@@ -572,6 +581,11 @@ void SourceLookupCache::getOperatorDecls(
572581
results.append(ops.second.begin(), ops.second.end());
573582
}
574583

584+
llvm::SmallVector<AbstractFunctionDecl *, 0>
585+
SourceLookupCache::getCustomDerivativeDecls() {
586+
return CustomDerivatives;
587+
}
588+
575589
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
576590
TinyPtrVector<OperatorDecl *> &results) {
577591
auto ops = Operators.find(name);
@@ -4008,6 +4022,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
40084022
(!sdkOrPlatform.empty() && pathStartsWith(sdkOrPlatform, modulePath));
40094023
}
40104024

4025+
evaluator::SideEffect
4026+
CustomDerivativesRequest::evaluate(Evaluator &evaluator,
4027+
ModuleDecl *module) const {
4028+
assert(isParsedModule(module));
4029+
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
4030+
module->getSourceLookupCache().getCustomDerivativeDecls();
4031+
for (const AbstractFunctionDecl *afd : decls) {
4032+
for (const auto *derAttr :
4033+
afd->getAttrs().getAttributes<DerivativeAttr>()) {
4034+
// Resolve derivative function configurations from `@derivative`
4035+
// attributes by type-checking them.
4036+
(void)derAttr->getOriginalFunction(SF.getASTContext());
4037+
}
4038+
}
4039+
return {};
4040+
}
4041+
40114042
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
40124043
for (auto *F : getFiles()) {
40134044
auto *LD = dyn_cast<LoadedFile>(F);

lib/Sema/TypeChecker.cpp

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -404,34 +404,13 @@ void swift::loadDerivativeConfigurations(SourceFile &SF) {
404404
FrontendStatsTracer tracer(Ctx.Stats,
405405
"load-derivative-configurations");
406406

407-
class DerivativeFinder : public ASTWalker {
408-
public:
409-
DerivativeFinder() {}
410-
411-
MacroWalking getMacroWalkingBehavior() const override {
412-
return MacroWalking::Expansion;
413-
}
414-
415-
PreWalkAction walkToDeclPre(Decl *D) override {
416-
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
417-
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
418-
// Resolve derivative function configurations from `@derivative`
419-
// attributes by type-checking them.
420-
(void)derAttr->getOriginalFunction(D->getASTContext());
421-
}
422-
}
423-
424-
return Action::Continue();
425-
}
426-
};
427-
428407
switch (SF.Kind) {
429408
case SourceFileKind::DefaultArgument:
430409
case SourceFileKind::Library:
431410
case SourceFileKind::MacroExpansion:
432411
case SourceFileKind::Main: {
433-
DerivativeFinder finder;
434-
SF.walkContext(finder);
412+
CustomDerivativesRequest request(SF.getParentModule());
413+
evaluateOrDefault(SF.getASTContext().evaluator, request, {});
435414
return;
436415
}
437416
case SourceFileKind::SIL:

0 commit comments

Comments
 (0)