Skip to content

[AutoDiff] Handle differentiation of @autodiff functions. #21792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,12 @@ static WitnessMethodInst *findWitnessMethod(SILValue value) {
return nullptr;
}

/// Emits a reference to the VJP of `original`, differentiated with respect to a
/// superset of `desiredIndices`. Returns the `SILValue` for the VJP and the
/// actual indices that the VJP is with respect to.
/// Emits a reference to an associated function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the associated function and the actual indices that the associated function
/// is with respect to.
///
/// On failure, returns `None`.
/// Returns `None` on failure.
///
/// Creates new differentiation tasks, if necessary, using `invoker` as the
/// invoker. Calls `taskCallback` for all newly-created tasks (but may also call
Expand All @@ -1843,15 +1844,27 @@ emitAssociatedFunctionReference(
DifferentiationInvoker invoker,
std::function<void(DifferentiationTask *)> taskCallback) {

auto fnType = original->getType().castTo<SILFunctionType>();
if (fnType->isDifferentiable()) {
SILValue assocFn = builder.createAutoDiffFunctionExtract(original.getLoc(),
kind, 1, original);
if (!fnType->getDifferentiationParameterIndices()
.test(desiredIndices.parameters))
return None;
SILAutoDiffIndices indices(0, desiredIndices.parameters);
return std::make_pair(assocFn, indices);
// If `original` is itself an `AutoDiffFunctionExtractInst` whose kind matches
// the given kind and desired differentiation parameter indices, simply
// extract the associated function of its function operand, retain the
// associated function, and return it.
if (auto *inst = original->getDefiningInstruction()) {
if (auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(inst)) {
if (adfei->getExtractee() == AutoDiffFunctionExtractee::Original) {
builder.createRetainValue(original.getLoc(), adfei->getFunctionOperand(),
builder.getDefaultAtomicity());
SILValue assocFn = builder.createAutoDiffFunctionExtract(
original.getLoc(), kind, /*differentiationOrder*/ 1,
adfei->getFunctionOperand());
auto autodiffFnType =
adfei->getFunctionOperand()->getType().castTo<SILFunctionType>();
if (autodiffFnType->getDifferentiationParameterIndices().test(
desiredIndices.parameters))
return None;
SILAutoDiffIndices indices(0, desiredIndices.parameters);
return std::make_pair(assocFn, indices);
}
}
}

// TODO: Refactor this function to recursively handle function conversions,
Expand Down
7 changes: 7 additions & 0 deletions test/AutoDiff/custom_derivatives.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,11 @@ CustomDerivativesTests.test("Make a differentiable function") {
expectEqual(20, gradient(at: 10, in: diffableFoo))
}

CustomDerivativesTests.test("Differentiation of @autodiff function") {
let diffableFoo = differentiableFunction { x in
(value: foo(x), pullback: { v in v * x * 2 })
}
expectEqual(20, gradient(at: 10, in: { x in diffableFoo(x) }))
}

runAllTests()