Skip to content

Commit cf52777

Browse files
authored
[AutoDiff] Make force-unwrapping differentiable. (#26826)
`Optional` force-unwrapping is a mathematically transposable operation. This patch makes force-unwrapping differentiable by applying its transpose on adjoint buffers. ```swift func bla<T: Differentiable & FloatingPoint>(_ t: T) -> (T, Float) where T == T.TangentVector { gradient(at: t, Float(1)) { (x, y) in (x as! Float) * y } } print(bla(Float(2))) // (1, 2) ``` Resolves [TF-455](https://bugs.swift.org/browse/TF-455).
1 parent 51886e4 commit cf52777

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,12 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
18201820
if (isVaried(cai->getSrc(), i))
18211821
recursivelySetVaried(cai->getDest(), i);
18221822
}
1823+
// Handle `unconditional_checked_cast_addr`.
1824+
else if (auto *uccai =
1825+
dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
1826+
if (isVaried(uccai->getSrc(), i))
1827+
recursivelySetVaried(uccai->getDest(), i);
1828+
}
18231829
// Handle `tuple_element_addr`.
18241830
else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
18251831
if (isVaried(teai->getOperand(), i)) {
@@ -1941,6 +1947,12 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
19411947
if (isUseful(cai->getDest(), i))
19421948
propagateUsefulThroughBuffer(cai->getSrc(), i);
19431949
}
1950+
// Handle `unconditional_checked_cast_addr`.
1951+
else if (auto *uccai =
1952+
dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
1953+
if (isUseful(uccai->getDest(), i))
1954+
propagateUsefulThroughBuffer(uccai->getSrc(), i);
1955+
}
19441956
// Handle everything else.
19451957
else if (llvm::any_of(inst.getResults(),
19461958
[&](SILValue res) { return isUseful(res, i); })) {
@@ -6601,6 +6613,27 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
66016613
}
66026614
}
66036615

6616+
/// Handle `unconditional_checked_cast_addr` instruction.
6617+
/// Original: y = unconditional_checked_cast_addr x
6618+
/// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
6619+
void visitUnconditionalCheckedCastAddrInst(
6620+
UnconditionalCheckedCastAddrInst *uccai) {
6621+
auto *bb = uccai->getParent();
6622+
auto &adjDest = getAdjointBuffer(bb, uccai->getDest());
6623+
auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc());
6624+
if (errorOccurred)
6625+
return;
6626+
auto destType = remapType(adjDest->getType());
6627+
auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
6628+
builder.createUnconditionalCheckedCastAddr(
6629+
uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf,
6630+
adjSrc->getType().getASTType());
6631+
addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc());
6632+
builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf);
6633+
builder.createDeallocStack(uccai->getLoc(), castBuf);
6634+
emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc());
6635+
}
6636+
66046637
#define NOT_DIFFERENTIABLE(INST, DIAG) \
66056638
void visit##INST##Inst(INST##Inst *inst) { \
66066639
getContext().emitNondifferentiabilityError( \

test/AutoDiff/simple_math.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,4 +315,13 @@ SimpleMathTests.test("SubsetIndices") {
315315
expectEqual(4, gradWRTNonDiff { x, y in x + Float(y) })
316316
}
317317

318+
SimpleMathTests.test("ForceUnwrapping") {
319+
func bla<T: Differentiable & FloatingPoint>(_ t: T) -> (T, Float) where T == T.TangentVector {
320+
gradient(at: t, Float(1)) { (x, y) in
321+
(x as! Float) * y
322+
}
323+
}
324+
expectEqual((1, 2), bla(Float(2)))
325+
}
326+
318327
runAllTests()

0 commit comments

Comments
 (0)