Skip to content

Commit b0a038b

Browse files
authored
Merge pull request #41401 from asl/sr15205-fix
[AutoDiff] Fix several issues related to captured arguments
2 parents 4c792f9 + 91458b4 commit b0a038b

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
714714
for (auto &param : diffParams) {
715715
// Skip `inout` parameters, which semantically behave as original results
716716
// and always appear as pullback parameters.
717-
if (param.isIndirectInOut())
717+
if (param.isIndirectMutating())
718718
continue;
719719
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
720720
param.getInterfaceType(), lookupConformance,

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,14 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
452452
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
453453
config.parameterIndices,
454454
original->getInterfaceType()->castTo<AnyFunctionType>());
455+
456+
if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) {
457+
assert(original->getCaptureInfo().hasLocalCaptures());
458+
silParameterIndices =
459+
silParameterIndices->extendingCapacity(original->getASTContext(),
460+
parameterIndices->getCapacity());
461+
}
462+
455463
// If all indices in `parameterIndices` are in `daParameterIndices`, and
456464
// it has fewer indices than our current candidate and a primitive VJP,
457465
// then `attr` is our new candidate.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// SR-15205: fix assertions related to captured arguments, they should
4+
// treated as constants
5+
6+
import _Differentiation
7+
8+
func outerFunc(value: inout Float) -> (Float, (Float) -> (Float, Float)) {
9+
10+
@differentiable(reverse, wrt: param)
11+
func innerFunc(param: Float, other: Float) -> Float {
12+
value += param * other
13+
return value * param * 2.0
14+
}
15+
16+
let valAndPullback = valueWithPullback(at: value, 2.0, of: innerFunc)
17+
return (value + valAndPullback.value, valAndPullback.pullback)
18+
}
19+
20+
func outerFunc2(value: inout Float) -> (Float, (Float) -> Float) {
21+
22+
@differentiable(reverse, wrt: param)
23+
func innerFunc(param: Float, other: Float) -> Float {
24+
value += param * other
25+
return value * param * 2.0
26+
}
27+
28+
@differentiable(reverse)
29+
func curriedFunc(param: Float) -> Float {
30+
return innerFunc(param: param, other: 3.0)
31+
}
32+
33+
let valAndPullback = valueWithPullback(at: value, of: curriedFunc)
34+
return (value + valAndPullback.value, valAndPullback.pullback)
35+
}
36+

0 commit comments

Comments
 (0)