Skip to content

Commit 77e6c08

Browse files
committed
[AutoDiff] [CS] Formalize function tupling behavior for @differentiable
Reject tupling into a `@differentiable` function, and allow the stripping of `@noDerivative` given we're also losing `@differentiable`.
1 parent cab39bf commit 77e6c08

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,19 +2205,27 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
22052205
// take multiple arguments to be passed as an argument in places
22062206
// that expect a function that takes a single tuple (of the same
22072207
// arity);
2208-
auto canImplodeParams = [&](ArrayRef<AnyFunctionType::Param> params) {
2208+
auto canImplodeParams = [&](ArrayRef<AnyFunctionType::Param> params,
2209+
const FunctionType *destFn) {
22092210
if (params.size() == 1)
22102211
return false;
22112212

2213+
// We do not support imploding into a @differentiable function.
2214+
if (destFn->isDifferentiable())
2215+
return false;
2216+
22122217
for (auto &param : params) {
22132218
// We generally cannot handle parameter flags, though we can carve out an
22142219
// exception for ownership flags such as __owned, which we can thunk, and
22152220
// flags that can freely dropped from a function type such as
2216-
// @_nonEphemeral.
2221+
// @_nonEphemeral. Note that @noDerivative can also be freely dropped, as
2222+
// we've already ensured that the destination function is not
2223+
// @differentiable.
22172224
auto flags = param.getParameterFlags();
22182225
flags = flags.withValueOwnership(
22192226
param.isInOut() ? ValueOwnership::InOut : ValueOwnership::Default);
2220-
flags = flags.withNonEphemeral(false);
2227+
flags = flags.withNonEphemeral(false)
2228+
.withNoDerivative(false);
22212229
if (!flags.isNone())
22222230
return false;
22232231
}
@@ -2264,12 +2272,12 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
22642272
if (last != path.rend()) {
22652273
if (last->getKind() == ConstraintLocator::ApplyArgToParam) {
22662274
if (isSingleTupleParam(ctx, func2Params) &&
2267-
canImplodeParams(func1Params)) {
2275+
canImplodeParams(func1Params, /*destFn*/ func2)) {
22682276
implodeParams(func1Params);
22692277
increaseScore(SK_FunctionConversion);
22702278
} else if (!ctx.isSwiftVersionAtLeast(5) &&
22712279
isSingleTupleParam(ctx, func1Params) &&
2272-
canImplodeParams(func2Params)) {
2280+
canImplodeParams(func2Params, /*destFn*/ func1)) {
22732281
auto *simplified = locator.trySimplifyToExpr();
22742282
// We somehow let tuple unsplatting function conversions
22752283
// through in some cases in Swift 4, so let's let that
@@ -2305,11 +2313,11 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
23052313
// 2. `case .bar(let tuple) = e` allows to match multiple
23062314
// parameters with a single tuple argument.
23072315
if (isSingleTupleParam(ctx, func1Params) &&
2308-
canImplodeParams(func2Params)) {
2316+
canImplodeParams(func2Params, /*destFn*/ func1)) {
23092317
implodeParams(func2Params);
23102318
increaseScore(SK_FunctionConversion);
23112319
} else if (isSingleTupleParam(ctx, func2Params) &&
2312-
canImplodeParams(func1Params)) {
2320+
canImplodeParams(func1Params, /*destFn*/ func2)) {
23132321
implodeParams(func1Params);
23142322
increaseScore(SK_FunctionConversion);
23152323
}
@@ -2320,7 +2328,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
23202328
auto *anchor = locator.trySimplifyToExpr();
23212329
if (isa_and_nonnull<ClosureExpr>(anchor) &&
23222330
isSingleTupleParam(ctx, func2Params) &&
2323-
canImplodeParams(func1Params)) {
2331+
canImplodeParams(func1Params, /*destFn*/ func2)) {
23242332
auto *fix = AllowClosureParamDestructuring::create(
23252333
*this, func2, getConstraintLocator(anchor));
23262334
if (recordFix(fix))

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,3 +726,24 @@ struct Accessors: Differentiable {
726726
// expected-error @+1 {{cannot differentiate functions returning opaque result types}}
727727
@differentiable(reverse)
728728
func opaqueResult(_ x: Float) -> some Differentiable { x }
729+
730+
// Test the function tupling conversion with @differentiable.
731+
func tuplify<Ts, U>(_ fn: @escaping (Ts) -> U) -> (Ts) -> U { fn }
732+
func tuplifyDifferentiable<Ts : Differentiable, U>(_ fn: @escaping @differentiable(reverse) (Ts) -> U) -> @differentiable(reverse) (Ts) -> U { fn }
733+
734+
func testTupling(withoutNoDerivative: @escaping @differentiable(reverse) (Float, Float) -> Float,
735+
withNoDerivative: @escaping @differentiable(reverse) (Float, @noDerivative Float) -> Float) {
736+
// We support tupling of differentiable functions as long as they drop @differentiable.
737+
let _: ((Float, Float)) -> Float = tuplify(withoutNoDerivative)
738+
let fn1 = tuplify(withoutNoDerivative)
739+
_ = fn1((0, 0))
740+
741+
// In this case we also drop @noDerivative.
742+
let _: ((Float, Float)) -> Float = tuplify(withNoDerivative)
743+
let fn2 = tuplify(withNoDerivative)
744+
_ = fn2((0, 0))
745+
746+
// We do not support tupling into an @differentiable function.
747+
let _ = tuplifyDifferentiable(withoutNoDerivative) // expected-error {{cannot convert value of type '@differentiable(reverse) (Float, Float) -> Float' to expected argument type '@differentiable(reverse) (Float) -> Float'}}
748+
let _ = tuplifyDifferentiable(withNoDerivative) // expected-error {{cannot convert value of type '@differentiable(reverse) (Float, @noDerivative Float) -> Float' to expected argument type '@differentiable(reverse) (Float) -> Float'}}
749+
}

0 commit comments

Comments
 (0)