Skip to content

Commit 642f2e4

Browse files
committed
Emit @noDerivative fixit only when valid.
Emit `@noDerivative` fixit only when there is at least one valid differentiability/linearity parameter. Otherwise, adding `@noDerivative` results in an ill-formed `@differentiable` function type.
1 parent f7ae573 commit 642f2e4

File tree

3 files changed

+59
-22
lines changed

3 files changed

+59
-22
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4097,8 +4097,10 @@ ERROR(attr_only_on_parameters_of_differentiable,none,
40974097
ERROR(differentiable_function_type_invalid_parameter,none,
40984098
"parameter type '%0' does not conform to 'Differentiable'"
40994099
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4100-
"function type is '@differentiable%select{|(linear)}1'; did you want to "
4101-
"add '@noDerivative' to this parameter?", (StringRef, bool))
4100+
"function type is '@differentiable%select{|(linear)}1'"
4101+
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
4102+
(StringRef, /*tangentVectorEqualsSelf*/ bool,
4103+
/*hasValidDifferentiabilityParameter*/ bool))
41024104
ERROR(differentiable_function_type_invalid_result,none,
41034105
"result type '%0' does not conform to 'Differentiable'"
41044106
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "

lib/Sema/TypeCheckType.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,29 +2581,48 @@ bool TypeResolver::resolveASTFunctionTypeParams(
25812581
// SWIFT_ENABLE_TENSORFLOW END
25822582
}
25832583

2584-
// SWIFT_ENABLE_TENSORFLOW
2585-
// All non-`@noDerivative` parameters of `@differentiable` and
2586-
// `@differentiable(linear)` function types must be differentiable.
2587-
if (diffKind != DifferentiabilityKind::NonDifferentiable &&
2588-
resolution.getStage() != TypeResolutionStage::Structural) {
2589-
if (!noDerivative) {
2590-
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2591-
if (!isDifferentiable(ty, /*tangentVectorEqualsSelf*/ isLinear)) {
2592-
diagnose(eltTypeRepr->getLoc(),
2593-
diag::differentiable_function_type_invalid_parameter,
2594-
ty->getString(), isLinear)
2595-
.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
2596-
}
2597-
}
2598-
}
2599-
// SWIFT_ENABLE_TENSORFLOW END
2600-
26012584
auto paramFlags = ParameterTypeFlags::fromParameterType(
26022585
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership,
26032586
noDerivative);
26042587
elements.emplace_back(ty, Identifier(), paramFlags);
26052588
}
26062589

2590+
// SWIFT_ENABLE_TENSORFLOW
2591+
// All non-`@noDerivative` parameters of `@differentiable` and
2592+
// `@differentiable(linear)` function types must be differentiable.
2593+
if (diffKind != DifferentiabilityKind::NonDifferentiable &&
2594+
resolution.getStage() != TypeResolutionStage::Structural) {
2595+
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2596+
// Emit `@noDerivative` fixit only if there is at least one valid
2597+
// differentiability/linearity parameter. Otherwise, adding `@noDerivative`
2598+
// produces an ill-formed function type.
2599+
auto hasValidDifferentiabilityParam =
2600+
llvm::find_if(elements, [&](AnyFunctionType::Param param) {
2601+
if (param.isNoDerivative())
2602+
return false;
2603+
return isDifferentiable(param.getPlainType(),
2604+
/*tangentVectorEqualsSelf*/ isLinear);
2605+
}) != elements.end();
2606+
// });
2607+
for (unsigned i = 0, end = inputRepr->getNumElements(); i != end; ++i) {
2608+
auto *eltTypeRepr = inputRepr->getElementType(i);
2609+
auto param = elements[i];
2610+
if (param.isNoDerivative())
2611+
continue;
2612+
auto paramType = param.getPlainType();
2613+
if (isDifferentiable(paramType, /*tangentVectorEqualsSelf*/ isLinear))
2614+
continue;
2615+
auto paramTypeString = paramType->getString();
2616+
auto diagnostic =
2617+
diagnose(eltTypeRepr->getLoc(),
2618+
diag::differentiable_function_type_invalid_parameter,
2619+
paramTypeString, isLinear, hasValidDifferentiabilityParam);
2620+
if (hasValidDifferentiabilityParam)
2621+
diagnostic.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
2622+
}
2623+
}
2624+
// SWIFT_ENABLE_TENSORFLOW END
2625+
26072626
return false;
26082627
}
26092628

test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,28 @@ let _: @differentiable (Float) throws -> Float
88
//===----------------------------------------------------------------------===//
99

1010
struct NonDiffType { var x: Int }
11+
1112
// FIXME: Properly type-check parameters and the result's differentiability
12-
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'; did you want to add '@noDerivative' to this parameter?}} {{25-25=@noDerivative }}
13+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1314
let _: @differentiable (NonDiffType) -> Float
15+
16+
// Emit `@noDerivative` fixit iff there is at least one valid differentiability parameter.
17+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'; did you want to add '@noDerivative' to this parameter?}} {{32-32=@noDerivative }}
18+
let _: @differentiable (Float, NonDiffType) -> Float
19+
20+
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
21+
let _: @differentiable(linear) (Float) -> NonDiffType
22+
23+
// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
24+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?}} {{40-40=@noDerivative }}
25+
let _: @differentiable(linear) (Float, NonDiffType) -> Float
26+
1427
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1528
let _: @differentiable (Float) -> NonDiffType
1629

30+
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
31+
let _: @differentiable(linear) (Float) -> NonDiffType
32+
1733
let _: @differentiable(linear) (Float) -> Float
1834

1935
//===----------------------------------------------------------------------===//
@@ -115,7 +131,7 @@ func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<
115131

116132
// expected-error @+5 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
117133
// expected-error @+4 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
118-
// expected-error @+3 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?}}
134+
// expected-error @+3 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
119135
// expected-error @+2 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
120136
// expected-note @+1 2 {{where 'T' = 'Int'}}
121137
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Vector<T>) -> Vector<U>) {}
@@ -132,7 +148,7 @@ inferredConformancesGeneric(diff) // okay!
132148
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}
133149
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
134150
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
135-
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?}}
151+
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
136152
// expected-error @+1 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
137153
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Vector<T>) -> Vector<U> {}
138154

0 commit comments

Comments
 (0)