Skip to content

[AutoDiff] Support differentiating global constant closures. #22004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,10 @@ NOTE(autodiff_function_generic_functions_unsupported,none,
NOTE(autodiff_external_nondifferentiable_function,none,
"cannot differentiate an external function that has not been marked "
"'@differentiable'", ())
NOTE(autodiff_global_let_closure_not_differentiable,none,
"global constant closure is not differentiable", ())
NOTE(autodiff_cannot_differentiate_global_var_closures,none,
"cannot differentiate global mutable closures", ())
NOTE(autodiff_protocol_member_not_differentiable,none,
"member is not differentiable because the corresponding protocol "
"requirement is not '@differentiable'", ())
Expand Down
156 changes: 94 additions & 62 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ template <typename T> static inline void debugDump(T &v) {
<< v << "\n==== END DEBUG DUMP ====\n");
}

/// Returns true if the module we are compiling is in an LLDB REPL.
///
/// The check is purely name-based: the Swift module's name is tested for the
/// `__lldb_expr_` prefix.
static bool isInLLDBREPL(SILModule &module) {
  // TODO(SR-9704): Use a more principled way to do this check.
  return module.getSwiftModule()->getNameStr().startswith("__lldb_expr_");
}

/// Creates arguments in the entry block based on the function type.
Expand Down Expand Up @@ -197,31 +197,6 @@ static CanType joinElementTypesFromValues(SILValueRange &&range,
return TupleType::get(elts, ctx)->getCanonicalType();
}

/// Walks backwards from a function-typed value through the instructions that
/// produced it, looking for the `function_ref` it ultimately came from.
///
/// `thin_to_thick_function`, `convert_function`,
/// `convert_escape_to_noescape` and `partial_apply` (via its callee) are
/// looked through. The `function_ref` is returned only when the referenced
/// function lives in the same module as the referencing instruction or is
/// serialized; in every other case the result is null.
static FunctionRefInst *findReferenceToVisibleFunction(SILValue value) {
  SILValue current = value;
  while (true) {
    auto *definingInst = current->getDefiningInstruction();
    // Values without a defining instruction cannot be traced any further.
    if (!definingInst)
      return nullptr;

    if (auto *funcRef = dyn_cast<FunctionRefInst>(definingInst)) {
      auto *callee = funcRef->getReferencedFunction();
      bool isVisible = &callee->getModule() == &definingInst->getModule() ||
                       callee->isSerialized() == IsSerialized;
      // A `function_ref` terminates the walk whether or not it is visible.
      return isVisible ? funcRef : nullptr;
    }

    // Step through a single layer of function conversion, if there is one.
    if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(definingInst))
      current = thinToThick->getOperand();
    else if (auto *convert = dyn_cast<ConvertFunctionInst>(definingInst))
      current = convert->getOperand();
    else if (auto *noEscape =
                 dyn_cast<ConvertEscapeToNoEscapeInst>(definingInst))
      current = noEscape->getOperand();
    else if (auto *partialApply = dyn_cast<PartialApplyInst>(definingInst))
      current = partialApply->getCallee();
    else
      return nullptr;
  }
}

/// Given an operator name, such as "+", and a protocol, returns the
/// "+" operator with type `(Self, Self) -> Self`. If the operator does not
/// exist in the protocol, returns null.
Expand Down Expand Up @@ -693,6 +668,15 @@ class DifferentiationTask {
SILFunction *getJVP() const { return jvp; }
SILFunction *getVJP() const { return vjp; }

/// Returns the associated function of the requested kind: the JVP for
/// `AutoDiffAssociatedFunctionKind::JVP`, the VJP for
/// `AutoDiffAssociatedFunctionKind::VJP`.
SILFunction *getAssociatedFunction(AutoDiffAssociatedFunctionKind kind) {
  switch (kind) {
  case AutoDiffAssociatedFunctionKind::JVP:
    return jvp;
  case AutoDiffAssociatedFunctionKind::VJP:
    return vjp;
  }
  // The switch covers every enumerator. This keeps control from flowing off
  // the end of a non-void function (and silences -Wreturn-type) if `kind`
  // ever holds an out-of-range value.
  llvm_unreachable("Unhandled AutoDiffAssociatedFunctionKind");
}

DenseMap<ApplyInst *, NestedApplyActivity> &getNestedApplyActivities() {
return nestedApplyActivities;
}
Expand Down Expand Up @@ -1581,21 +1565,18 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
llvm_unreachable("Unhandled function convertion instruction");
}

/// Looks through function conversion instructions to find an underlying witness
/// method instruction. Returns `nullptr` if `value` does not come from a
/// `witness_method` or if there are unhandled conversion instructions between
/// `value` and the `witness_method`.
static WitnessMethodInst *findWitnessMethod(SILValue value) {
if (auto *witnessMethod = dyn_cast<WitnessMethodInst>(value))
return witnessMethod;
template<class Inst>
static Inst *peerThroughFunctionConversions(SILValue value) {
if (auto *inst = dyn_cast<Inst>(value))
return inst;
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
return findWitnessMethod(thinToThick->getOperand());
return peerThroughFunctionConversions<Inst>(thinToThick->getOperand());
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
return findWitnessMethod(convertFn->getOperand());
return peerThroughFunctionConversions<Inst>(convertFn->getOperand());
if (auto *convertFn = dyn_cast<ConvertEscapeToNoEscapeInst>(value))
return findWitnessMethod(convertFn->getOperand());
return peerThroughFunctionConversions<Inst>(convertFn->getOperand());
if (auto *partialApply = dyn_cast<PartialApplyInst>(value))
return findWitnessMethod(partialApply->getCallee());
return peerThroughFunctionConversions<Inst>(partialApply->getCallee());
return nullptr;
}

Expand All @@ -1615,8 +1596,8 @@ static WitnessMethodInst *findWitnessMethod(SILValue value) {
static Optional<std::pair<SILValue, SILAutoDiffIndices>>
emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
const DifferentiationTask *parentTask, SILAutoDiffIndices desiredIndices,
AutoDiffAssociatedFunctionKind kind,
SILValue original, DifferentiationInvoker invoker,
AutoDiffAssociatedFunctionKind kind, SILValue original,
DifferentiationInvoker invoker,
std::function<void(DifferentiationTask *)> taskCallback) {

// If `original` is itself an `AutoDiffFunctionExtractInst` whose kind matches
Expand Down Expand Up @@ -1646,24 +1627,21 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
}
}

// TODO: Refactor this function to recursively handle function conversions,
// rather than using `findReferenceToVisibleFunction`, `findWitnessMethod`,
// and `reapplyFunctionConversion`.

if (auto *originalFRI = findReferenceToVisibleFunction(original)) {
// Find local function reference.
if (auto *originalFRI =
peerThroughFunctionConversions<FunctionRefInst>(original)) {
auto loc = originalFRI->getLoc();
auto *originalFn = originalFRI->getReferencedFunction();
auto *task =
context.lookUpMinimalDifferentiationTask(originalFn, desiredIndices);
if (!task) {
if (originalFn->isExternalDeclaration()) {
// For lldb repl, we should attempt to load the function as
// For LLDB REPL, we should attempt to load the function as
// this may be defined in a different cell.
if (isInLLDBREPL(*original->getModule())) {
if (isInLLDBREPL(*original->getModule()))
original->getModule()->loadFunction(originalFn);
}
// If we still don't have the definition, generate an error message.
if (!originalFn->isDefinition()) {
if (originalFn->isExternalDeclaration()) {
context.emitNondifferentiabilityError(
original, parentTask,
diag::autodiff_external_nondifferentiable_function);
Expand All @@ -1675,22 +1653,75 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
}
assert(task);
taskCallback(task);
SILFunction *assocFn = nullptr;
switch (kind) {
case AutoDiffAssociatedFunctionKind::JVP:
assocFn = task->getJVP();
break;
case AutoDiffAssociatedFunctionKind::VJP:
assocFn = task->getVJP();
break;
}
auto *ref = builder.createFunctionRef(loc, assocFn);
auto convertedRef =
reapplyFunctionConversion(ref, originalFRI, original, builder, loc);
auto *ref =
builder.createFunctionRef(loc, task->getAssociatedFunction(kind));
auto convertedRef = reapplyFunctionConversion(
ref, originalFRI, original, builder, loc);
return std::make_pair(convertedRef, task->getIndices());
}

if (auto *witnessMethod = findWitnessMethod(original)) {
// Find global `let` closure.
if (auto *load = peerThroughFunctionConversions<LoadInst>(original)) {
FunctionRefInst *initialFnRef = nullptr;
SILValue initVal;
if (auto *globalAddr = dyn_cast<GlobalAddrInst>(load->getOperand())) {
// Search for the original function used to initialize this `let`
// constant.
if (auto *global = globalAddr->getReferencedGlobal()) {
if (!global->isLet()) {
context.emitNondifferentiabilityError(original, parentTask,
diag::autodiff_cannot_differentiate_global_var_closures);
return None;
}
// FIXME: In LLDB REPL, "main" will not be the function we should look
// for.
if (auto *mainFn = global->getModule().lookUpFunction("main")) {
if (mainFn->isDefinition())
for (auto &inst : mainFn->front())
if (auto *globalAddrInMain = dyn_cast<GlobalAddrInst>(&inst))
if (globalAddrInMain->getReferencedGlobal() == global)
for (auto *use : globalAddrInMain->getUses())
if (auto *store = dyn_cast<StoreInst>(use->getUser()))
if (store->getDest() == globalAddrInMain)
initialFnRef = peerThroughFunctionConversions
<FunctionRefInst>((initVal = store->getSrc()));
}
}
}
if (initialFnRef) {
assert(initVal);
auto *initialFn = initialFnRef->getReferencedFunction();
auto *task =
context.lookUpMinimalDifferentiationTask(initialFn, desiredIndices);
if (!task) {
if (initialFn->isExternalDeclaration()) {
if (isInLLDBREPL(*original->getModule()))
original->getModule()->loadFunction(initialFn);
if (initialFn->isExternalDeclaration()) {
context.emitNondifferentiabilityError(original, parentTask,
diag::autodiff_global_let_closure_not_differentiable);
return None;
}
}
task = context.registerDifferentiationTask(
initialFn, desiredIndices, invoker);
}
auto loc = original.getLoc();
auto *initialVJPRef = builder.createFunctionRef(
loc, task->getAssociatedFunction(kind));
auto converted =
reapplyFunctionConversion(initialVJPRef, initialFnRef, initVal,
builder, loc);
converted = reapplyFunctionConversion(converted, load, original,
builder, loc);
SILAutoDiffIndices indices(0, desiredIndices.parameters);
return std::make_pair(converted, indices);
}
}

// Find witness method retrieval.
if (auto *witnessMethod =
peerThroughFunctionConversions<WitnessMethodInst>(original)) {
auto loc = witnessMethod->getLoc();
auto requirement = witnessMethod->getMember();
auto *requirementDecl = requirement.getDecl();
Expand Down Expand Up @@ -1735,6 +1766,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
return std::make_pair(convertedRef, requirementIndices);
}

// Emit the general opaque function error.
context.emitNondifferentiabilityError(original, parentTask,
diag::autodiff_opaque_function_not_differentiable);
return None;
Expand Down
19 changes: 19 additions & 0 deletions test/AutoDiff/simple_math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,23 @@ SimpleMathTests.test("CaptureGlobal") {
expectEqual(30, gradient(at: 0, in: foo))
}

// Global `let` constant holding a closure that squares its argument.
let foo: (Float) -> Float = { x in
return x * x
}
// Differentiates the global constant closure by name and expects the gradient
// at x = 1 to be 2.
SimpleMathTests.test("GlobalLet") {
expectEqual(2, gradient(at: 1, in: foo))
}

// Global mutable `@autodiff` function value, built with
// `differentiableFunction` from a closure that returns the pair
// (x * x, { v in 2 * x * v }).
var foo_diffable: @autodiff (Float) -> (Float)
= differentiableFunction { x in (x * x, { v in 2 * x * v }) }
SimpleMathTests.test("GlobalDiffableFunc") {
// Passed directly, and called from inside a closure: both expect 2 at x = 1.
expectEqual(2, gradient(at: 1, in: foo_diffable))
expectEqual(2, gradient(at: 1, in: { x in foo_diffable(x) }))
// The differentiated closure reassigns the global to `{ x in x + 1 }` before
// calling it; the expected gradient is then 1.
expectEqual(1, gradient(at: 1, in: { (x: Float) -> Float in
foo_diffable = { x in x + 1 };
return foo_diffable(x)
}))
// Runs after the reassignment above, so the expected gradient is still 1.
expectEqual(1, gradient(at: 1, in: foo_diffable))
}

runAllTests()
5 changes: 5 additions & 0 deletions test/TensorFlowRuntime/tensor_autodiff_runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,9 @@ TensorADTests.testAllBackends("SR-9345: OwnedCheckpoints") {
expectEqual(Tensor(1.0), pb(Tensor(1)))
}

// Global `let` constant closure that cubes a tensor by repeated multiplication.
let cube: (Tensor<Float>) -> Tensor<Float> = { $0 * $0 * $0 }
// Differentiates the global constant closure and expects the gradient at
// Tensor(4) to equal Tensor(48).
TensorADTests.testAllBackends("DifferentiateGlobal") {
expectEqual(Tensor(48), gradient(at: Tensor(4), in: cube))
}

runAllTests()