Skip to content

Commit 01339e7

Browse files
bgogulrxwei
authored andcommitted
Allow differentiation to look across modules for definitions in repl mode (#21992)
* Allow differentiation to look across modules for function definitions when running in lldb repl mode. * Refactor repl check to a static function.
1 parent 5dbd653 commit 01339e7

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ template <typename T> static inline void debugDump(T &v) {
6868
<< v << "\n==== END DEBUG DUMP ====\n");
6969
}
7070

71+
static bool isInLLDBREPL(SILModule &module) {
72+
llvm::StringRef module_name = module.getSwiftModule()->getNameStr();
73+
// TODO(SR-9704): Use a more prinicpled way to do this check.
74+
return module_name.startswith("__lldb_expr_");
75+
}
76+
7177
/// Creates arguments in the entry block based on the function type.
7278
static void createEntryArguments(SILFunction *f) {
7379
auto *entry = f->getEntryBlock();
@@ -1651,12 +1657,21 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
16511657
context.lookUpMinimalDifferentiationTask(originalFn, desiredIndices);
16521658
if (!task) {
16531659
if (originalFn->isExternalDeclaration()) {
1654-
context.emitNondifferentiabilityError(original, parentTask,
1655-
diag::autodiff_external_nondifferentiable_function);
1656-
return None;
1660+
// For lldb repl, we should attempt to load the function as
1661+
// this may be defined in a different cell.
1662+
if (isInLLDBREPL(*original->getModule())) {
1663+
original->getModule()->loadFunction(originalFn);
1664+
}
1665+
// If we still don't have the definition, generate an error message.
1666+
if (!originalFn->isDefinition()) {
1667+
context.emitNondifferentiabilityError(
1668+
original, parentTask,
1669+
diag::autodiff_external_nondifferentiable_function);
1670+
return None;
1671+
}
16571672
}
1658-
task = context.registerDifferentiationTask(
1659-
originalFn, desiredIndices, invoker);
1673+
task = context.registerDifferentiationTask(originalFn, desiredIndices,
1674+
invoker);
16601675
}
16611676
assert(task);
16621677
taskCallback(task);

0 commit comments

Comments
 (0)