Skip to content

Commit b6acb6f

Browse files
authored
[AutoDiff] Check @noDerivative for function type comparison (#67121)
As described in the issue #62922, the compiler should not allow to discard @noDerivative attribute and keep @differentiable. The patch adds a diagnostic for this case. Resolves #62922.
1 parent 019bd3c commit b6acb6f

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3496,6 +3496,13 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
34963496
return getTypeMatchFailure(argumentLocator);
34973497
}
34983498

3499+
// If functions are differentiable, ensure that @noDerivative is not
3500+
// discarded.
3501+
if (func1->isDifferentiable() && func2->isDifferentiable() &&
3502+
func1Param.isNoDerivative() && !func2Param.isNoDerivative()) {
3503+
return getTypeMatchFailure(argumentLocator);
3504+
}
3505+
34993506
// FIXME: We should check value ownership too, but it's not completely
35003507
// trivial because of inout-to-pointer conversions.
35013508

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,8 @@ struct TF_675 : Differentiable {
741741
let _: @differentiable(reverse) (Float) -> Float = TF_675().method
742742

743743
// TF-918: Test parameter subset thunk + partially-applied original function.
744-
let _: @differentiable(reverse) (Float, Float) -> Float = (+) as @differentiable(reverse) (Float, @noDerivative Float) -> Float
744+
let _: @differentiable(reverse) (Float, @noDerivative Float) -> Float = (+) as @differentiable(reverse) (Float, Float) -> Float
745+
let _: @differentiable(reverse) (Float, @noDerivative Float) -> Float = (+) as @differentiable(reverse) (Float, @noDerivative Float) -> Float
745746

746747
//===----------------------------------------------------------------------===//
747748
// Differentiation in fragile functions

test/Constraints/noderivative.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
import _Differentiation
4+
5+
// Allow Type -> @noDerivative Type
6+
//
7+
func test1(_ foo: @escaping @differentiable(reverse) (Float, Float) -> Float) {
8+
let fn: @differentiable(reverse) (Float, @noDerivative Float) -> Float = foo
9+
_ = fn(0, 0)
10+
}
11+
12+
// Allow @noDerivative Type -> Type when LHS function is not differentiable
13+
//
14+
func test2(_ foo: @escaping @differentiable(reverse) (Float, @noDerivative Float) -> Float) {
15+
let fn: (Float, Float) -> Float = foo
16+
_ = fn(0, 0)
17+
}
18+
19+
// Disallow @noDerivative Type -> Type when LHS function is also differentiable
20+
//
21+
func test3(_ foo: @escaping @differentiable(reverse) (Float, @noDerivative Float) -> Float) {
22+
// expected-error @+1 {{cannot convert value of type '@differentiable(reverse) (Float, @noDerivative Float) -> Float' to specified type '@differentiable(reverse) (Float, Float) -> Float'}}
23+
let fn: @differentiable(reverse) (Float, Float) -> Float = foo
24+
_ = fn(0, 0)
25+
}

0 commit comments

Comments
 (0)