Skip to content

Commit 6e2c4fa

Browse files
authored
[AutoDiff] Lookup for custom derivatives in non-primary source files (#58965)
* Lookup for custom derivatives in non-primary source files after typecheck is finished for the primary source. This registers all custom derivatives before autodiff transformations and makes them available to them. Fully resolves #55170
1 parent 2679294 commit 6e2c4fa

File tree

9 files changed

+111
-37
lines changed

9 files changed

+111
-37
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6277,11 +6277,9 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
62776277
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
62786278

62796279
public:
6280-
/// Get all derivative function configurations. If `lookInNonPrimarySources`
6281-
/// is true then lookup is done in non-primary sources as well. Note that
6282-
/// such lookup might end in cycles if done during sema stages.
6280+
/// Get all derivative function configurations.
62836281
ArrayRef<AutoDiffConfig>
6284-
getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true);
6282+
getDerivativeFunctionConfigurations();
62856283

62866284
/// Add the given derivative function configuration.
62876285
void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);

include/swift/Frontend/Frontend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,7 @@ class CompilerInstance {
669669

670670
/// If \p fn returns true, exits early and returns true.
671671
bool forEachFileToTypeCheck(llvm::function_ref<bool(SourceFile &)> fn);
672+
bool forEachSourceFile(llvm::function_ref<bool(SourceFile &)> fn);
672673

673674
/// Whether the cancellation of the current operation has been requested.
674675
bool isCancellationRequested() const;

include/swift/Subsystems.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ namespace swift {
157157
/// emitted.
158158
void performWholeModuleTypeChecking(SourceFile &SF);
159159

160+
/// Load derivative configurations from @derivative attributes (including
161+
/// those defined in non-primary sources).
162+
void loadDerivativeConfigurations(SourceFile &SF);
163+
160164
/// Resolve the given \c TypeRepr to an interface type.
161165
///
162166
/// This is used when dealing with partial source files (e.g. SIL parsing,

lib/AST/Decl.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8321,7 +8321,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
83218321
}
83228322

83238323
ArrayRef<AutoDiffConfig>
8324-
AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) {
8324+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
83258325
prepareDerivativeFunctionConfigurations();
83268326

83278327
// Resolve derivative function configurations from `@differentiable`
@@ -8345,36 +8345,6 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimaryS
83458345
*DerivativeFunctionConfigs);
83468346
}
83478347

8348-
class DerivativeFinder : public ASTWalker {
8349-
const AbstractFunctionDecl *AFD;
8350-
public:
8351-
DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {}
8352-
8353-
bool walkToDeclPre(Decl *D) override {
8354-
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
8355-
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
8356-
// Resolve derivative function configurations from `@derivative`
8357-
// attributes by type-checking them.
8358-
if (AFD->getName().matchesRef(
8359-
derAttr->getOriginalFunctionName().Name.getFullName())) {
8360-
(void)derAttr->getOriginalFunction(afd->getASTContext());
8361-
return false;
8362-
}
8363-
}
8364-
}
8365-
8366-
return true;
8367-
}
8368-
};
8369-
8370-
// Load derivative configurations from @derivative attributes defined in
8371-
// non-primary sources. Note that it might trigger lookup cycles if called
8372-
// from inside Sema stages.
8373-
if (lookInNonPrimarySources) {
8374-
DerivativeFinder finder(this);
8375-
getParent()->walkContext(finder);
8376-
}
8377-
83788348
return DerivativeFunctionConfigs->getArrayRef();
83798349
}
83808350

lib/Frontend/Frontend.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,11 +1185,31 @@ bool CompilerInstance::forEachFileToTypeCheck(
11851185
return false;
11861186
}
11871187

1188+
bool CompilerInstance::forEachSourceFile(
1189+
llvm::function_ref<bool(SourceFile &)> fn) {
1190+
for (auto fileName : getMainModule()->getFiles()) {
1191+
auto *SF = dyn_cast<SourceFile>(fileName);
1192+
if (!SF) {
1193+
continue;
1194+
}
1195+
if (fn(*SF))
1196+
return true;
1197+
;
1198+
}
1199+
1200+
return false;
1201+
}
1202+
11881203
void CompilerInstance::finishTypeChecking() {
11891204
forEachFileToTypeCheck([](SourceFile &SF) {
11901205
performWholeModuleTypeChecking(SF);
11911206
return false;
11921207
});
1208+
1209+
forEachSourceFile([](SourceFile &SF) {
1210+
loadDerivativeConfigurations(SF);
1211+
return false;
1212+
});
11931213
}
11941214

11951215
SourceFile::ParsingOptions

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,7 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
379379
bool foundExactConfig = false;
380380
Optional<AutoDiffConfig> supersetConfig = None;
381381
for (auto witnessConfig :
382-
witnessAFD->getDerivativeFunctionConfigurations(
383-
/*lookInNonPrimarySources*/ false)) {
382+
witnessAFD->getDerivativeFunctionConfigurations()) {
384383
// All the witness's derivative generic requirements must be satisfied
385384
// by the requirement's derivative generic requirements OR by the
386385
// conditional conformance requirements.

lib/Sema/TypeChecker.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,44 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) {
379379
}
380380
}
381381

382+
void swift::loadDerivativeConfigurations(SourceFile &SF) {
383+
if (!isDifferentiableProgrammingEnabled(SF))
384+
return;
385+
386+
auto &Ctx = SF.getASTContext();
387+
FrontendStatsTracer tracer(Ctx.Stats,
388+
"load-derivative-configurations");
389+
390+
class DerivativeFinder : public ASTWalker {
391+
public:
392+
DerivativeFinder() {}
393+
394+
bool walkToDeclPre(Decl *D) override {
395+
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
396+
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
397+
// Resolve derivative function configurations from `@derivative`
398+
// attributes by type-checking them.
399+
(void)derAttr->getOriginalFunction(D->getASTContext());
400+
}
401+
}
402+
403+
return true;
404+
}
405+
};
406+
407+
switch (SF.Kind) {
408+
case SourceFileKind::Library:
409+
case SourceFileKind::Main: {
410+
DerivativeFinder finder;
411+
SF.walkContext(finder);
412+
return;
413+
}
414+
case SourceFileKind::SIL:
415+
case SourceFileKind::Interface:
416+
return;
417+
}
418+
}
419+
382420
bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) {
383421
auto &ctx = SF.getASTContext();
384422
// Return true if `AdditiveArithmetic` derived conformances are explicitly
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import _Differentiation
2+
3+
@inlinable
4+
@derivative(of: min)
5+
func minVJP<T: Comparable & Differentiable>(
6+
_ x: T,
7+
_ y: T
8+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
9+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
10+
if x <= y {
11+
return (v, .zero)
12+
}
13+
else {
14+
return (.zero, v)
15+
}
16+
}
17+
return (value: min(x, y), pullback: pullback)
18+
}
19+
20+
@inlinable
21+
@derivative(of: max)
22+
func maxVJP<T: Comparable & Differentiable>(
23+
_ x: T,
24+
_ y: T
25+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
26+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
27+
if x < y {
28+
return (.zero, v)
29+
}
30+
else {
31+
return (v, .zero)
32+
}
33+
}
34+
return (value: max(x, y), pullback: pullback)
35+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null
2+
3+
import _Differentiation
4+
5+
@differentiable(reverse)
6+
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
7+
// No error expected
8+
return max(min(value, upperBound), lowerBound)
9+
}

0 commit comments

Comments
 (0)