Skip to content

Commit 8bab5c9

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Handle differentiation of @autodiff functions. (#21792)
* [AutoDiff] Handle differentiation of `@autodiff` functions. Given an `AutoDiffFunctionExtractInst` in PrimalGen, get its VJP by simply extracting the VJP of its function operand. Thanks for detailed guidance from @rxwei. * Fix retain and shorten test. `autodiff_function_extract` consumes its function operand. Thus, the function operand should be retained before `autodiff_function_extract`. * Generalize code to extract given associated function kind. Update comments.
1 parent 66437a8 commit 8bab5c9

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,11 +1826,12 @@ static WitnessMethodInst *findWitnessMethod(SILValue value) {
18261826
return nullptr;
18271827
}
18281828

1829-
/// Emits a reference to the VJP of `original`, differentiated with respect to a
1830-
/// superset of `desiredIndices`. Returns the `SILValue` for the VJP and the
1831-
/// actual indices that the VJP is with respect to.
1829+
/// Emits a reference to an associated function of `original`, differentiated
1830+
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
1831+
/// the associated function and the actual indices that the associated function
1832+
/// is with respect to.
18321833
///
1833-
/// On failure, returns `None`.
1834+
/// Returns `None` on failure.
18341835
///
18351836
/// Creates new differentiation tasks, if necessary, using `invoker` as the
18361837
/// invoker. Calls `taskCallback` for all newly-created tasks (but may also call
@@ -1843,15 +1844,27 @@ emitAssociatedFunctionReference(
18431844
DifferentiationInvoker invoker,
18441845
std::function<void(DifferentiationTask *)> taskCallback) {
18451846

1846-
auto fnType = original->getType().castTo<SILFunctionType>();
1847-
if (fnType->isDifferentiable()) {
1848-
SILValue assocFn = builder.createAutoDiffFunctionExtract(original.getLoc(),
1849-
kind, 1, original);
1850-
if (!fnType->getDifferentiationParameterIndices()
1851-
.test(desiredIndices.parameters))
1852-
return None;
1853-
SILAutoDiffIndices indices(0, desiredIndices.parameters);
1854-
return std::make_pair(assocFn, indices);
1847+
// If `original` is itself an `AutoDiffFunctionExtractInst` whose kind matches
1848+
// the given kind and desired differentiation parameter indices, simply
1849+
// extract the associated function of its function operand, retain the
1850+
// associated function, and return it.
1851+
if (auto *inst = original->getDefiningInstruction()) {
1852+
if (auto *adfei = dyn_cast<AutoDiffFunctionExtractInst>(inst)) {
1853+
if (adfei->getExtractee() == AutoDiffFunctionExtractee::Original) {
1854+
builder.createRetainValue(original.getLoc(), adfei->getFunctionOperand(),
1855+
builder.getDefaultAtomicity());
1856+
SILValue assocFn = builder.createAutoDiffFunctionExtract(
1857+
original.getLoc(), kind, /*differentiationOrder*/ 1,
1858+
adfei->getFunctionOperand());
1859+
auto autodiffFnType =
1860+
adfei->getFunctionOperand()->getType().castTo<SILFunctionType>();
1861+
if (autodiffFnType->getDifferentiationParameterIndices().test(
1862+
desiredIndices.parameters))
1863+
return None;
1864+
SILAutoDiffIndices indices(0, desiredIndices.parameters);
1865+
return std::make_pair(assocFn, indices);
1866+
}
1867+
}
18551868
}
18561869

18571870
// TODO: Refactor this function to recursively handle function conversions,

test/AutoDiff/custom_derivatives.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,11 @@ CustomDerivativesTests.test("Make a differentiable function") {
2424
expectEqual(20, gradient(at: 10, in: diffableFoo))
2525
}
2626

27+
CustomDerivativesTests.test("Differentiation of @autodiff function") {
28+
let diffableFoo = differentiableFunction { x in
29+
(value: foo(x), pullback: { v in v * x * 2 })
30+
}
31+
expectEqual(20, gradient(at: 10, in: { x in diffableFoo(x) }))
32+
}
33+
2734
runAllTests()

0 commit comments

Comments
 (0)