Skip to content

Commit b88a119

Browse files
authored
[AutoDiff] Fix cloned curry thunk verification error. (#28662)
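Differentiation has special support for canonicalizing `differentiable_function` instructions of curry thunk applications: the curry thunk is cloned into a version that returns a `differentiable_function` of the original return value. In cloned curry thunks, the `differentiable_function` instruction is now created immediately after the instruction that defines the return value (e.g. `partial_apply`), rather than at the end of the function just before `return`. Creating it at the end placed it after the thunk's existing `dealloc_stack` instructions, which caused `alloc_stack`/`dealloc_stack` ordering failures in SIL verification; creating it before any `dealloc_stack` avoids them. Resolves TF-1039.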
1 parent c82b9f3 commit b88a119

File tree

2 files changed

+74
-17
lines changed

2 files changed

+74
-17
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,12 +1392,14 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
13921392

13931393
auto thunkTy = thunk->getLoweredFunctionType();
13941394
auto thunkResult = thunkTy->getSingleResult();
1395-
if (auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>()) {
1396-
// Construct new curry thunk type with `@differentiable` result.
1397-
auto diffableResultFnTy = resultFnTy->getWithExtInfo(
1398-
resultFnTy->getExtInfo()
1399-
.withDifferentiabilityKind(DifferentiabilityKind::Normal));
1400-
auto newThunkResult = thunkResult.getWithInterfaceType(diffableResultFnTy);
1395+
if (auto resultFnTy =
1396+
thunkResult.getInterfaceType()->getAs<SILFunctionType>()) {
1397+
// Construct new curry thunk type with `@differentiable` function
1398+
// result.
1399+
auto diffResultFnTy = resultFnTy->getWithExtInfo(
1400+
resultFnTy->getExtInfo().withDifferentiabilityKind(
1401+
DifferentiabilityKind::Normal));
1402+
auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
14011403
auto thunkType = SILFunctionType::get(
14021404
thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(),
14031405
thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
@@ -1425,12 +1427,18 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
14251427
cloner.run();
14261428
auto *retInst =
14271429
cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
1428-
SILBuilder thunkBuilder(retInst);
1429-
auto *dfi = context.createDifferentiableFunction(thunkBuilder, loc,
1430-
parameterIndices,
1431-
retInst->getOperand());
1430+
auto returnValue = retInst->getOperand();
1431+
// Create `differentiable_function` instruction directly after the
1432+
// defining instruction (e.g. `partial_apply`) of the returned value.
1433+
// Note: `differentiable_function` is not created at the end of the
1434+
// new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
1435+
SILBuilder dfiBuilder(
1436+
std::next(returnValue->getDefiningInstruction()->getIterator()));
1437+
auto *dfi = context.createDifferentiableFunction(
1438+
dfiBuilder, loc, parameterIndices, returnValue);
14321439
context.setResultIndex(dfi, resultIndex);
1433-
thunkBuilder.createReturn(loc, dfi);
1440+
dfiBuilder.setInsertionPoint(newThunk->findReturnBB());
1441+
dfiBuilder.createReturn(loc, dfi);
14341442
retInst->eraseFromParent();
14351443

14361444
context.recordGeneratedFunction(newThunk);
@@ -1450,12 +1458,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
14501458
auto *newApply = builder.createApply(
14511459
ai->getLoc(), newThunkRef, ai->getSubstitutionMap(), newArgs,
14521460
ai->isNonThrowing());
1453-
for (auto arg : newArgsToDestroy) {
1454-
if (arg->getType().isObject())
1455-
builder.emitDestroyValueOperation(loc, arg);
1456-
else
1457-
builder.emitDestroyAddr(loc, arg);
1458-
}
1461+
for (auto arg : newArgsToDestroy)
1462+
builder.emitDestroyOperation(loc, arg);
14591463
for (auto *alloc : newBuffersToDealloc)
14601464
builder.createDeallocStack(loc, alloc);
14611465
return newApply;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: %target-swift-emit-sil %s -verify
2+
// REQUIRES: asserts
3+
4+
// TF-1039: Cloned curry thunks generated by differentiation should create a
5+
// `differentiable_function` instruction before any `dealloc_stack` instructions
6+
// to prevent `alloc_stack`/`dealloc_stack` ordering issues.
7+
8+
protocol P {
9+
@differentiable
10+
func foo(_ x: Float) -> Float
11+
}
12+
extension P {
13+
@derivative(of: foo)
14+
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
15+
return (x, { $0 })
16+
}
17+
}
18+
struct S: P {
19+
@differentiable
20+
func foo(_ x: Float) -> Float { x }
21+
}
22+
func foo<T: P>(_ x: T) {
23+
// Curry thunk emitted here for `x.foo`.
24+
_ = gradient(at: 1, in: x.foo)
25+
}
26+
27+
// SIL verification failed: stack dealloc does not match most recent stack alloc: op == state.Stack.back()
28+
// Verifying instruction:
29+
// %2 = alloc_stack $τ_0_0 // users: %7, %5, %9, %8, %3
30+
// -> dealloc_stack %2 : $*τ_0_0 // id: %9
31+
// In function:
32+
// // AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0
33+
// sil shared [thunk] @AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0) -> @owned @differentiable @callee_guaranteed (Float) -> Float {
34+
// // %0 // user: %3
35+
// bb0(%0 : $*τ_0_0):
36+
// %1 = witness_method $τ_0_0, #P.foo!1 : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> Float // user: %8
37+
// %2 = alloc_stack $τ_0_0 // users: %7, %5, %9, %8, %3
38+
// copy_addr %0 to [initialization] %2 : $*τ_0_0 // id: %3
39+
// %4 = alloc_stack $τ_0_0 // users: %15, %11, %5
40+
// copy_addr %2 to [initialization] %4 : $*τ_0_0 // id: %5
41+
// %6 = alloc_stack $τ_0_0 // users: %14, %13, %7
42+
// copy_addr %2 to [initialization] %6 : $*τ_0_0 // id: %7
43+
// %8 = partial_apply [callee_guaranteed] %1<τ_0_0>(%2) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> Float // user: %16
44+
// dealloc_stack %2 : $*τ_0_0 // id: %9
45+
// %10 = witness_method $τ_0_0, #P.foo!1.jvp.SU : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %11
46+
// %11 = partial_apply [callee_guaranteed] %10<τ_0_0>(%4) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %16
47+
// %12 = witness_method $τ_0_0, #P.foo!1.vjp.SU : <Self where Self : P> (Self) -> (Float) -> Float : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %13
48+
// %13 = partial_apply [callee_guaranteed] %12<τ_0_0>(%6) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (Float, @in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %16
49+
// dealloc_stack %6 : $*τ_0_0 // id: %14
50+
// dealloc_stack %4 : $*τ_0_0 // id: %15
51+
// %16 = differentiable_function [parameters 0] %8 : $@callee_guaranteed (Float) -> Float with_derivative {%11 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %13 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} // user: %17
52+
// return %16 : $@differentiable @callee_guaranteed (Float) -> Float // id: %17
53+
// } // end sil function 'AD__$s5curry1PP3fooyS2fFTc__differentiable_curry_thunk_src_0_wrt_0'

0 commit comments

Comments
 (0)