Skip to content

Commit 8d56a06

Browse files
authored
[AutoDiff] Not differentiate through global closures flow-sensitively. (#25676)
Earlier we added a hack to differentiate global constants or variables that are initialized with a closure whose body is differentiable. However, function conversion should not be flow sensitive. This patch removes the hack.
1 parent 8f5f77a commit 8d56a06

File tree

2 files changed

+0
-89
lines changed

2 files changed

+0
-89
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,88 +2086,6 @@ emitAssociatedFunctionReference(
20862086
return std::make_pair(convertedRef, minimalAttr->getIndices());
20872087
}
20882088

2089-
// Find global `let` closure.
2090-
if (auto *load = peerThroughFunctionConversions<LoadInst>(original)) {
2091-
FunctionRefInst *initialFnRef = nullptr;
2092-
SILValue initVal;
2093-
if (auto *globalAddr = dyn_cast<GlobalAddrInst>(load->getOperand())) {
2094-
// Search for the original function used to initialize this `let`
2095-
// constant.
2096-
if (auto *global = globalAddr->getReferencedGlobal()) {
2097-
if (!global->isLet()) {
2098-
context.emitNondifferentiabilityError(original, invoker,
2099-
diag::autodiff_cannot_differentiate_global_var_closures);
2100-
return None;
2101-
}
2102-
// FIXME: In LLDB REPL, "main" will not be the function we should look
2103-
// for.
2104-
if (auto *mainFn = global->getModule().lookUpFunction("main")) {
2105-
if (mainFn->isDefinition())
2106-
for (auto &inst : mainFn->front())
2107-
if (auto *globalAddrInMain = dyn_cast<GlobalAddrInst>(&inst))
2108-
if (globalAddrInMain->getReferencedGlobal() == global)
2109-
for (auto *use : globalAddrInMain->getUses())
2110-
if (auto *store = dyn_cast<StoreInst>(use->getUser()))
2111-
if (store->getDest() == globalAddrInMain)
2112-
initialFnRef = peerThroughFunctionConversions
2113-
<FunctionRefInst>((initVal = store->getSrc()));
2114-
}
2115-
}
2116-
}
2117-
if (initialFnRef) {
2118-
assert(initVal);
2119-
auto *initialFn = initialFnRef->getReferencedFunctionOrNull();
2120-
auto *minimalAttr =
2121-
context.lookUpMinimalDifferentiableAttr(initialFn, desiredIndices);
2122-
if (!minimalAttr) {
2123-
if (initialFn->isExternalDeclaration()) {
2124-
context.emitNondifferentiabilityError(
2125-
original, invoker,
2126-
diag::autodiff_global_let_closure_not_differentiable);
2127-
return None;
2128-
}
2129-
ArrayRef<Requirement> contextualRequirements;
2130-
if (invoker.getKind() ==
2131-
DifferentiationInvoker::Kind::IndirectDifferentiation)
2132-
contextualRequirements =
2133-
invoker.getIndirectDifferentiation().second->getRequirements();
2134-
auto *newAttr = context.getOrCreateDifferentiableAttr(
2135-
initialFn, desiredIndices, contextualRequirements);
2136-
bool error = context.processDifferentiableAttribute(
2137-
initialFn, newAttr, invoker);
2138-
if (error)
2139-
return None;
2140-
minimalAttr = newAttr;
2141-
}
2142-
if (context.processDifferentiableAttribute(
2143-
initialFn, minimalAttr, invoker))
2144-
return None;
2145-
SILFunction *assocFn = nullptr;
2146-
switch (kind) {
2147-
case AutoDiffAssociatedFunctionKind::JVP:
2148-
assert(!minimalAttr->getJVPName().empty() && "Expected JVP name");
2149-
assocFn = context.getModule().lookUpFunction(minimalAttr->getJVPName());
2150-
break;
2151-
case AutoDiffAssociatedFunctionKind::VJP:
2152-
assert(!minimalAttr->getVJPName().empty() && "Expected VJP name");
2153-
assocFn = context.getModule().lookUpFunction(minimalAttr->getVJPName());
2154-
break;
2155-
}
2156-
assert(assocFn && "Associated function must be resolved");
2157-
auto assocFnGenSig =
2158-
assocFn->getLoweredFunctionType()->getGenericSignature();
2159-
auto loc = original.getLoc();
2160-
auto *initialVJPRef = builder.createFunctionRef(loc, assocFn);
2161-
auto converted =
2162-
reapplyFunctionConversion(initialVJPRef, initialFnRef, initVal,
2163-
builder, loc, assocFnGenSig);
2164-
converted =
2165-
reapplyFunctionConversion(converted, load, original,
2166-
builder, loc, assocFnGenSig);
2167-
return std::make_pair(converted, minimalAttr->getIndices());
2168-
}
2169-
}
2170-
21712089
// Find witness method retrieval.
21722090
if (auto *witnessMethod =
21732091
peerThroughFunctionConversions<WitnessMethodInst>(original)) {

test/AutoDiff/simple_math.swift

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,6 @@ SimpleMathTests.test("CaptureGlobal") {
7373
expectEqual(30, gradient(at: 0, in: foo))
7474
}
7575

76-
let foo: (Float) -> Float = { x in
77-
return x * x
78-
}
79-
SimpleMathTests.test("GlobalLet") {
80-
expectEqual(2, gradient(at: 1, in: foo))
81-
}
82-
8376
var foo_diffable: @differentiable (Float) -> (Float)
8477
= differentiableFunction { x in (x * x, { v in 2 * x * v }) }
8578
SimpleMathTests.test("GlobalDiffableFunc") {

0 commit comments

Comments
 (0)