Skip to content

[AutoDiff] Fix several issues related to captured arguments #41401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
for (auto &param : diffParams) {
// Skip `inout` parameters, which semantically behave as original results
// and always appear as pullback parameters.
if (param.isIndirectInOut())
if (param.isIndirectMutating())
continue;
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
param.getInterfaceType(), lookupConformance,
Expand Down
8 changes: 8 additions & 0 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,14 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
config.parameterIndices,
original->getInterfaceType()->castTo<AnyFunctionType>());

if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) {
assert(original->getCaptureInfo().hasLocalCaptures());
silParameterIndices =
silParameterIndices->extendingCapacity(original->getASTContext(),
parameterIndices->getCapacity());
}

Comment on lines +455 to +462
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a numCaptures parameter to autodiff::getLoweredParameterIndices and make sure we allocate the right capacity in the first place?

Copy link
Contributor Author

@asl asl Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are few workarounds of similar kind sprinkled over the code, so we'd need to eliminate them as well (see https://github.com/apple/swift/blob/main/lib/SILOptimizer/Mandatory/Differentiation.cpp#L518 as an example). Just setting numCaptures seems not work in general as we need to know how the captured arguments will be lowered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not a perfect solution. Since captured arguments are always at the end of a SIL function, we could get the number of captures from CaptureInfo. Tuple captures won't be splat in SIL arguments, so we might be able to assume the number of captures in CaptureInfo is the number of captures in the SIL function, but I'm not sure entirely.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me too :) I'm planning to resolve the TODO afterwards, after checking all possibilities and fixing getLoweredParameterIndices properly. There is a tentative patch for this, but it needs more testing and checking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as is since similar workarounds already exist, but it would be really nice to fix getLoweredParameterIndices. Let me know if you want to take a stab at that or rather merge this first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sounds good.

// If all indices in `parameterIndices` are in `daParameterIndices`, and
// it has fewer indices than our current candidate and a primitive VJP,
// then `attr` is our new candidate.
Expand Down
36 changes: 36 additions & 0 deletions test/AutoDiff/compiler_crashers_fixed/sr15205-diff-capture.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// SR-15205: fix assertions related to captured arguments, they should
// treated as constants

import _Differentiation

func outerFunc(value: inout Float) -> (Float, (Float) -> (Float, Float)) {

@differentiable(reverse, wrt: param)
func innerFunc(param: Float, other: Float) -> Float {
value += param * other
return value * param * 2.0
}

let valAndPullback = valueWithPullback(at: value, 2.0, of: innerFunc)
return (value + valAndPullback.value, valAndPullback.pullback)
}

func outerFunc2(value: inout Float) -> (Float, (Float) -> Float) {

@differentiable(reverse, wrt: param)
func innerFunc(param: Float, other: Float) -> Float {
value += param * other
return value * param * 2.0
}

@differentiable(reverse)
func curriedFunc(param: Float) -> Float {
return innerFunc(param: param, other: 3.0)
}

let valAndPullback = valueWithPullback(at: value, of: curriedFunc)
return (value + valAndPullback.value, valAndPullback.pullback)
}