Skip to content

Commit ffe8d78

Browse files
committed
[AutoDiff] Fix differentiation for non-wrt inout parameters.
Fix SIL differential function type calculation to handle non-wrt `inout` parameters. Patch `SILFunctionType::getDifferentiabilityResultIndices` to prevent returning empty result indices for `@differentiable` function types with no formal results where all `inout` parameters are `@noDerivative`. TF-1305 tracks a robust fix. Resolves SR-13305. Exposes TF-1305: parameter/result differentiability hole for `inout` parameters.
1 parent 6d534ec commit ffe8d78

File tree

3 files changed

+144
-4
lines changed

3 files changed

+144
-4
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,23 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
235235
resultIndices.push_back(resultAndIndex.index());
236236
// Check `inout` parameters.
237237
for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters()))
238-
if (inoutParamAndIndex.value().getDifferentiability() !=
239-
SILParameterDifferentiability::NotDifferentiable)
238+
// FIXME(TF-1305): The `getResults().empty()` condition is a hack.
239+
//
240+
// Currently, an `inout` parameter can either be:
241+
// 1. Both a differentiability parameter and a differentiability result.
242+
// 2. `@noDerivative`: neither a differentiability parameter nor a
243+
// differentiability result.
244+
// However, there is no way to represent an `inout` parameter that:
245+
// 3. Is a differentiability parameter but not a differentiability result.
246+
// 4. Is a differentiability result but not a differentiability parameter.
247+
//
248+
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout`
249+
// parameters are not treated as differentiability results, unless the
250+
// original function has no formal results, which case all `inout`
251+
// parameters are treated as differentiability results.
252+
if (getResults().empty() ||
253+
inoutParamAndIndex.value().getDifferentiability() !=
254+
SILParameterDifferentiability::NotDifferentiable)
240255
resultIndices.push_back(getNumResults() + inoutParamAndIndex.index());
241256
auto numSemanticResults =
242257
getNumResults() + getNumIndirectMutatingParameters();
@@ -428,8 +443,9 @@ static CanSILFunctionType getAutoDiffDifferentialType(
428443
}
429444
}
430445
SmallVector<SILResultInfo, 1> differentialResults;
431-
if (inoutParamIndices->isEmpty()) {
432-
for (auto resultIndex : resultIndices->getIndices()) {
446+
for (auto resultIndex : resultIndices->getIndices()) {
447+
// Handle formal original result.
448+
if (resultIndex < originalFnTy->getNumResults()) {
433449
auto &result = originalResults[resultIndex];
434450
auto resultTan =
435451
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
@@ -448,8 +464,27 @@ static CanSILFunctionType getAutoDiffDifferentialType(
448464
substReplacements.push_back(resultTanType);
449465
differentialResults.push_back({gpType, resultConv});
450466
}
467+
continue;
451468
}
469+
// Handle original `inout` parameter.
470+
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
471+
auto inoutParamIt = std::next(
472+
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
473+
auto paramIndex =
474+
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
475+
// If the original `inout` parameter is a differentiability parameter, then
476+
// it already has a corresponding differential parameter. Skip adding a
477+
// corresponding differential result.
478+
if (parameterIndices->contains(paramIndex))
479+
continue;
480+
auto inoutParam = originalFnTy->getParameters()[paramIndex];
481+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
482+
lookupConformance);
483+
assert(paramTan && "Parameter type does not have a tangent space?");
484+
differentialResults.push_back(
485+
{paramTan->getCanonicalType(), ResultConvention::Indirect});
452486
}
487+
453488
SubstitutionMap substitutions;
454489
if (!substGenericParams.empty()) {
455490
auto genericSig =
@@ -710,7 +745,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
710745
CanGenericSignature derivativeFnInvocationGenSig,
711746
bool isReabstractionThunk) {
712747
assert(parameterIndices);
748+
assert(!parameterIndices->isEmpty() && "Parameter indices must not be empty");
713749
assert(resultIndices);
750+
assert(!resultIndices->isEmpty() && "Result indices must not be empty");
714751
auto &ctx = getASTContext();
715752

716753
// Look up result in cache.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
// REQUIRES: asserts
3+
4+
import _Differentiation
5+
6+
// SR-13305: Test protocol witness thunk for `@differentiable` protocol
7+
// requirement, where the required method has a non-wrt `inout` parameter
8+
// that should be treated as a differentiability result.
9+
10+
protocol SR_13305_Protocol {
11+
@differentiable(wrt: x)
12+
func method(x: Float, y: inout Float)
13+
}
14+
15+
struct SR_13305_Struct: SR_13305_Protocol {
16+
@differentiable(wrt: x)
17+
func method(x: Float, y: inout Float) {
18+
y = y * x
19+
}
20+
}
21+
22+
// Original crash:
23+
// Assertion failed: (!array.empty() && "claiming next from empty array!"), function claimNext, file /Users/danielzheng/swift-build/swift/lib/SILGen/SILGenPoly.cpp, line 112.
24+
// Stack dump:
25+
// ...
26+
// 1. Swift version 5.3-dev (LLVM f8bd914aadc2e7b, Swift ba9c433c81d51ea)
27+
// 2. While evaluating request ASTLoweringRequest(Lowering AST to SIL for module main)
28+
// 3. While generating SIL witness table protocol conformance to 'SR_13305_Protocol' (at sr-13305.swift:7:1) for type 'SR_13305_Struct' (declared at [sr-13305.swift:12:1 - line:17:1] RangeText="struct SR_13305_Struct: SR_13305_Protocol {
29+
// @differentiable(wrt: x)
30+
// func method(x: Float, y: inout Float) {
31+
// y = y * x
32+
// }
33+
// ")
34+
// 4. While generating protocol witness thunk SIL function "@AD__$s4main15SR_13305_StructVAA0B15_13305_ProtocolA2aDP6method1x1yySf_SfztFTW_jvp_SUU".
35+
// for 'method(x:y:)' (at sr-13305.swift:14:3)
36+
// 5. While emitting reabstraction thunk in SIL function "@$sSfIegy_S2fIegyd_TR".
37+
// ...
38+
// 7 swift-frontend 0x0000000100fe80ad swift::SILResultInfo const& claimNext<swift::SILResultInfo>(llvm::ArrayRef<swift::SILResultInfo>&) + 93
39+
// 8 swift-frontend 0x0000000100fe6cc0 (anonymous namespace)::ResultPlanner::claimNextInnerResult((anonymous namespace)::ResultPlanner::PlanData&) + 32
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import DifferentiationUnittest
5+
import StdlibUnittest
6+
7+
var InoutParameterAutoDiffTests = TestSuite("InoutParameterDifferentiation")
8+
9+
// SR-13305: Test function with non-wrt `inout` parameter, which should be
10+
// treated as a differentiability result.
11+
12+
protocol SR_13305_Protocol {
13+
@differentiable(wrt: x)
14+
func method(_ x: Float, _ y: inout Float)
15+
16+
@differentiable(wrt: x)
17+
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T)
18+
}
19+
20+
InoutParameterAutoDiffTests.test("non-wrt inout parameter") {
21+
struct SR_13305_Struct: SR_13305_Protocol {
22+
@differentiable(wrt: x)
23+
func method(_ x: Float, _ y: inout Float) {
24+
y = y * x
25+
}
26+
27+
@differentiable(wrt: x)
28+
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T) {
29+
y = x
30+
}
31+
}
32+
33+
@differentiable(wrt: x)
34+
func foo(_ s: SR_13305_Struct, _ x: Float, _ y: Float) -> Float {
35+
var y = y
36+
s.method(x, &y)
37+
return y
38+
}
39+
40+
@differentiable(wrt: x)
41+
func fooGeneric<T: SR_13305_Protocol>(_ s: T, _ x: Float, _ y: Float) -> Float {
42+
var y = y
43+
s.method(x, &y)
44+
return x
45+
}
46+
47+
let s = SR_13305_Struct()
48+
49+
do {
50+
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { foo(s, $0, $1) })
51+
expectEqual(6, value)
52+
expectEqual((3, 2), (dx, dy))
53+
}
54+
expectEqual((value: 6, gradient: 3), valueWithGradient(at: 2, in: { foo(s, $0, 3) }))
55+
56+
do {
57+
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { fooGeneric(s, $0, $1) })
58+
expectEqual(2, value)
59+
expectEqual((1, 0), (dx, dy))
60+
}
61+
expectEqual((value: 2, gradient: 1), valueWithGradient(at: 2, in: { fooGeneric(s, $0, 3) }))
62+
}
63+
64+
runAllTests()

0 commit comments

Comments
 (0)