Skip to content

Commit 92cc593

Browse files
authored
[AutoDiff] Support differentiating global constant closures. (#22004)
* [AutoDiff] Support differentiating global constant closures. * When there is a reference to a global `let` closure, we find its initializer in `@main`, trace back to the original `function_ref`, and differentiate that. * NFC: Refactor `findReferenceToVisibleFunction` and `findWitnessMethod` to a single template. * Add a global constant closure test that uses tensors. * Add global differentiable closure tests.
1 parent 01339e7 commit 92cc593

File tree

4 files changed

+122
-62
lines changed

4 files changed

+122
-62
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ NOTE(autodiff_function_generic_functions_unsupported,none,
373373
NOTE(autodiff_external_nondifferentiable_function,none,
374374
"cannot differentiate an external function that has not been marked "
375375
"'@differentiable'", ())
376+
NOTE(autodiff_global_let_closure_not_differentiable,none,
377+
"global constant closure is not differentiable", ())
378+
NOTE(autodiff_cannot_differentiate_global_var_closures,none,
379+
"cannot differentiate global mutable closures", ())
376380
NOTE(autodiff_protocol_member_not_differentiable,none,
377381
"member is not differentiable because the corresponding protocol "
378382
"requirement is not '@differentiable'", ())

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 94 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ template <typename T> static inline void debugDump(T &v) {
6868
<< v << "\n==== END DEBUG DUMP ====\n");
6969
}
7070

71+
/// Returns true if the module we are compiling is in an LLDB REPL.
7172
static bool isInLLDBREPL(SILModule &module) {
72-
llvm::StringRef module_name = module.getSwiftModule()->getNameStr();
7373
// TODO(SR-9704): Use a more prinicpled way to do this check.
74-
return module_name.startswith("__lldb_expr_");
74+
return module.getSwiftModule()->getNameStr().startswith("__lldb_expr_");
7575
}
7676

7777
/// Creates arguments in the entry block based on the function type.
@@ -197,31 +197,6 @@ static CanType joinElementTypesFromValues(SILValueRange &&range,
197197
return TupleType::get(elts, ctx)->getCanonicalType();
198198
}
199199

200-
/// Looks through the definition of a function value. If the source that
201-
/// produced this function value is `function_ref` and the function is visible
202-
/// (either in the same module or is serialized), returns the instruction.
203-
/// Otherwise, returns null.
204-
static FunctionRefInst *findReferenceToVisibleFunction(SILValue value) {
205-
auto *inst = value->getDefiningInstruction();
206-
if (!inst)
207-
return nullptr;
208-
if (auto *fri = dyn_cast<FunctionRefInst>(inst)) {
209-
auto *fn = fri->getReferencedFunction();
210-
if (&fn->getModule() == &inst->getModule() ||
211-
fn->isSerialized() == IsSerialized)
212-
return fri;
213-
}
214-
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(inst))
215-
return findReferenceToVisibleFunction(thinToThick->getOperand());
216-
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(inst))
217-
return findReferenceToVisibleFunction(convertFn->getOperand());
218-
if (auto *convertFn = dyn_cast<ConvertEscapeToNoEscapeInst>(inst))
219-
return findReferenceToVisibleFunction(convertFn->getOperand());
220-
if (auto *partialApply = dyn_cast<PartialApplyInst>(inst))
221-
return findReferenceToVisibleFunction(partialApply->getCallee());
222-
return nullptr;
223-
}
224-
225200
/// Given an operator name, such as "+", and a protocol, returns the
226201
/// "+" operator with type `(Self, Self) -> Self`. If the operator does not
227202
/// exist in the protocol, returns null.
@@ -693,6 +668,15 @@ class DifferentiationTask {
693668
SILFunction *getJVP() const { return jvp; }
694669
SILFunction *getVJP() const { return vjp; }
695670

671+
SILFunction *getAssociatedFunction(AutoDiffAssociatedFunctionKind kind) {
672+
switch (kind) {
673+
case AutoDiffAssociatedFunctionKind::JVP:
674+
return jvp;
675+
case AutoDiffAssociatedFunctionKind::VJP:
676+
return vjp;
677+
}
678+
}
679+
696680
DenseMap<ApplyInst *, NestedApplyActivity> &getNestedApplyActivities() {
697681
return nestedApplyActivities;
698682
}
@@ -1581,21 +1565,18 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
15811565
llvm_unreachable("Unhandled function convertion instruction");
15821566
}
15831567

1584-
/// Looks through function conversion instructions to find an underlying witness
1585-
/// method instruction. Returns `nullptr` if `value` does not come from a
1586-
/// `witness_method` or if there are unhandled conversion instructions between
1587-
/// `value` and the `witness_method`..
1588-
static WitnessMethodInst *findWitnessMethod(SILValue value) {
1589-
if (auto *witnessMethod = dyn_cast<WitnessMethodInst>(value))
1590-
return witnessMethod;
1568+
template<class Inst>
1569+
static Inst *peerThroughFunctionConversions(SILValue value) {
1570+
if (auto *inst = dyn_cast<Inst>(value))
1571+
return inst;
15911572
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
1592-
return findWitnessMethod(thinToThick->getOperand());
1573+
return peerThroughFunctionConversions<Inst>(thinToThick->getOperand());
15931574
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
1594-
return findWitnessMethod(convertFn->getOperand());
1575+
return peerThroughFunctionConversions<Inst>(convertFn->getOperand());
15951576
if (auto *convertFn = dyn_cast<ConvertEscapeToNoEscapeInst>(value))
1596-
return findWitnessMethod(convertFn->getOperand());
1577+
return peerThroughFunctionConversions<Inst>(convertFn->getOperand());
15971578
if (auto *partialApply = dyn_cast<PartialApplyInst>(value))
1598-
return findWitnessMethod(partialApply->getCallee());
1579+
return peerThroughFunctionConversions<Inst>(partialApply->getCallee());
15991580
return nullptr;
16001581
}
16011582

@@ -1615,8 +1596,8 @@ static WitnessMethodInst *findWitnessMethod(SILValue value) {
16151596
static Optional<std::pair<SILValue, SILAutoDiffIndices>>
16161597
emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
16171598
const DifferentiationTask *parentTask, SILAutoDiffIndices desiredIndices,
1618-
AutoDiffAssociatedFunctionKind kind,
1619-
SILValue original, DifferentiationInvoker invoker,
1599+
AutoDiffAssociatedFunctionKind kind, SILValue original,
1600+
DifferentiationInvoker invoker,
16201601
std::function<void(DifferentiationTask *)> taskCallback) {
16211602

16221603
// If `original` is itself an `AutoDiffFunctionExtractInst` whose kind matches
@@ -1646,24 +1627,21 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
16461627
}
16471628
}
16481629

1649-
// TODO: Refactor this function to recursively handle function conversions,
1650-
// rather than using `findReferenceToVisibleFunction`, `findWitnessMethod`,
1651-
// and `reapplyFunctionConversion`.
1652-
1653-
if (auto *originalFRI = findReferenceToVisibleFunction(original)) {
1630+
// Find local function reference.
1631+
if (auto *originalFRI =
1632+
peerThroughFunctionConversions<FunctionRefInst>(original)) {
16541633
auto loc = originalFRI->getLoc();
16551634
auto *originalFn = originalFRI->getReferencedFunction();
16561635
auto *task =
16571636
context.lookUpMinimalDifferentiationTask(originalFn, desiredIndices);
16581637
if (!task) {
16591638
if (originalFn->isExternalDeclaration()) {
1660-
// For lldb repl, we should attempt to load the function as
1639+
// For LLDB REPL, we should attempt to load the function as
16611640
// this may be defined in a different cell.
1662-
if (isInLLDBREPL(*original->getModule())) {
1641+
if (isInLLDBREPL(*original->getModule()))
16631642
original->getModule()->loadFunction(originalFn);
1664-
}
16651643
// If we still don't have the definition, generate an error message.
1666-
if (!originalFn->isDefinition()) {
1644+
if (originalFn->isExternalDeclaration()) {
16671645
context.emitNondifferentiabilityError(
16681646
original, parentTask,
16691647
diag::autodiff_external_nondifferentiable_function);
@@ -1675,22 +1653,75 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
16751653
}
16761654
assert(task);
16771655
taskCallback(task);
1678-
SILFunction *assocFn = nullptr;
1679-
switch (kind) {
1680-
case AutoDiffAssociatedFunctionKind::JVP:
1681-
assocFn = task->getJVP();
1682-
break;
1683-
case AutoDiffAssociatedFunctionKind::VJP:
1684-
assocFn = task->getVJP();
1685-
break;
1686-
}
1687-
auto *ref = builder.createFunctionRef(loc, assocFn);
1688-
auto convertedRef =
1689-
reapplyFunctionConversion(ref, originalFRI, original, builder, loc);
1656+
auto *ref =
1657+
builder.createFunctionRef(loc, task->getAssociatedFunction(kind));
1658+
auto convertedRef = reapplyFunctionConversion(
1659+
ref, originalFRI, original, builder, loc);
16901660
return std::make_pair(convertedRef, task->getIndices());
16911661
}
16921662

1693-
if (auto *witnessMethod = findWitnessMethod(original)) {
1663+
// Find global `let` closure.
1664+
if (auto *load = peerThroughFunctionConversions<LoadInst>(original)) {
1665+
FunctionRefInst *initialFnRef = nullptr;
1666+
SILValue initVal;
1667+
if (auto *globalAddr = dyn_cast<GlobalAddrInst>(load->getOperand())) {
1668+
// Search for the original function used to initialize this `let`
1669+
// constant.
1670+
if (auto *global = globalAddr->getReferencedGlobal()) {
1671+
if (!global->isLet()) {
1672+
context.emitNondifferentiabilityError(original, parentTask,
1673+
diag::autodiff_cannot_differentiate_global_var_closures);
1674+
return None;
1675+
}
1676+
// FIXME: In LLDB REPL, "main" will not be the function we should look
1677+
// for.
1678+
if (auto *mainFn = global->getModule().lookUpFunction("main")) {
1679+
if (mainFn->isDefinition())
1680+
for (auto &inst : mainFn->front())
1681+
if (auto *globalAddrInMain = dyn_cast<GlobalAddrInst>(&inst))
1682+
if (globalAddrInMain->getReferencedGlobal() == global)
1683+
for (auto *use : globalAddrInMain->getUses())
1684+
if (auto *store = dyn_cast<StoreInst>(use->getUser()))
1685+
if (store->getDest() == globalAddrInMain)
1686+
initialFnRef = peerThroughFunctionConversions
1687+
<FunctionRefInst>((initVal = store->getSrc()));
1688+
}
1689+
}
1690+
}
1691+
if (initialFnRef) {
1692+
assert(initVal);
1693+
auto *initialFn = initialFnRef->getReferencedFunction();
1694+
auto *task =
1695+
context.lookUpMinimalDifferentiationTask(initialFn, desiredIndices);
1696+
if (!task) {
1697+
if (initialFn->isExternalDeclaration()) {
1698+
if (isInLLDBREPL(*original->getModule()))
1699+
original->getModule()->loadFunction(initialFn);
1700+
if (initialFn->isExternalDeclaration()) {
1701+
context.emitNondifferentiabilityError(original, parentTask,
1702+
diag::autodiff_global_let_closure_not_differentiable);
1703+
return None;
1704+
}
1705+
}
1706+
task = context.registerDifferentiationTask(
1707+
initialFn, desiredIndices, invoker);
1708+
}
1709+
auto loc = original.getLoc();
1710+
auto *initialVJPRef = builder.createFunctionRef(
1711+
loc, task->getAssociatedFunction(kind));
1712+
auto converted =
1713+
reapplyFunctionConversion(initialVJPRef, initialFnRef, initVal,
1714+
builder, loc);
1715+
converted = reapplyFunctionConversion(converted, load, original,
1716+
builder, loc);
1717+
SILAutoDiffIndices indices(0, desiredIndices.parameters);
1718+
return std::make_pair(converted, indices);
1719+
}
1720+
}
1721+
1722+
// Find witness method retrieval.
1723+
if (auto *witnessMethod =
1724+
peerThroughFunctionConversions<WitnessMethodInst>(original)) {
16941725
auto loc = witnessMethod->getLoc();
16951726
auto requirement = witnessMethod->getMember();
16961727
auto *requirementDecl = requirement.getDecl();
@@ -1735,6 +1766,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
17351766
return std::make_pair(convertedRef, requirementIndices);
17361767
}
17371768

1769+
// Emit the general opaque function error.
17381770
context.emitNondifferentiabilityError(original, parentTask,
17391771
diag::autodiff_opaque_function_not_differentiable);
17401772
return None;

test/AutoDiff/simple_math.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,23 @@ SimpleMathTests.test("CaptureGlobal") {
7373
expectEqual(30, gradient(at: 0, in: foo))
7474
}
7575

76+
let foo: (Float) -> Float = { x in
77+
return x * x
78+
}
79+
SimpleMathTests.test("GlobalLet") {
80+
expectEqual(2, gradient(at: 1, in: foo))
81+
}
82+
83+
var foo_diffable: @autodiff (Float) -> (Float)
84+
= differentiableFunction { x in (x * x, { v in 2 * x * v }) }
85+
SimpleMathTests.test("GlobalDiffableFunc") {
86+
expectEqual(2, gradient(at: 1, in: foo_diffable))
87+
expectEqual(2, gradient(at: 1, in: { x in foo_diffable(x) }))
88+
expectEqual(1, gradient(at: 1, in: { (x: Float) -> Float in
89+
foo_diffable = { x in x + 1 };
90+
return foo_diffable(x)
91+
}))
92+
expectEqual(1, gradient(at: 1, in: foo_diffable))
93+
}
94+
7695
runAllTests()

test/TensorFlowRuntime/tensor_autodiff_runtime.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,9 @@ TensorADTests.testAllBackends("SR-9345: OwnedCheckpoints") {
143143
expectEqual(Tensor(1.0), pb(Tensor(1)))
144144
}
145145

146+
let cube: (Tensor<Float>) -> Tensor<Float> = { $0 * $0 * $0 }
147+
TensorADTests.testAllBackends("DifferentiateGlobal") {
148+
expectEqual(Tensor(48), gradient(at: Tensor(4), in: cube))
149+
}
150+
146151
runAllTests()

0 commit comments

Comments
 (0)