Skip to content

Commit d69e892

Browse files
author
Marc Rasi
committed
[AutoDiff upstream] forbid @Derivative of protocol req
1 parent 36fff23 commit d69e892

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3503,6 +3503,9 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
35033503
};
35043504

35053505
auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) {
3506+
// TODO(TF-982): Allow derivatives on protocol requirements.
3507+
if (isa<ProtocolDecl>(originalCandidate->getDeclContext()))
3508+
return false;
35063509
return checkFunctionSignature(
35073510
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
35083511
originalCandidate->getInterfaceType()->getCanonicalType(),

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185185
static func generic<T: Differentiable>(_ x: T) -> T
186186
}
187187

188+
extension StaticMethod {
189+
static func foo(_ x: Float) -> Float { x }
190+
static func generic<T: Differentiable>(_ x: T) -> T { x }
191+
}
192+
188193
extension StaticMethod {
189194
@derivative(of: foo)
190195
static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
@@ -215,11 +220,16 @@ extension StaticMethod {
215220
// Test instance methods.
216221

217222
protocol InstanceMethod: Differentiable {
218-
// expected-note @+1 {{'foo' defined here}}
219223
func foo(_ x: Self) -> Self
224+
func generic<T: Differentiable>(_ x: T) -> Self
225+
}
226+
227+
extension InstanceMethod {
228+
// expected-note @+1 {{'foo' defined here}}
229+
func foo(_ x: Self) -> Self { x }
220230

221231
// expected-note @+1 {{'generic' defined here}}
222-
func generic<T: Differentiable>(_ x: T) -> Self
232+
func generic<T: Differentiable>(_ x: T) -> Self { self }
223233
}
224234

225235
extension InstanceMethod {
@@ -539,24 +549,14 @@ extension HasStoredProperty {
539549
// Test cross-file derivative registration. Currently unsupported.
540550
// TODO(TF-1021): Lift this restriction.
541551

542-
extension AdditiveArithmetic where Self: Differentiable {
552+
extension FloatingPoint where Self: Differentiable {
543553
// expected-error @+1 {{derivative not in the same file as the original function}}
544-
@derivative(of: +)
545-
static func vjpPlus(x: Self, y: Self) -> (
554+
@derivative(of: rounded)
555+
func vjpRounded() -> (
546556
value: Self,
547-
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
548-
) {
549-
return (x + y, { v in (v, v) })
550-
}
551-
}
552-
553-
extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector {
554-
// expected-error @+1 {{derivative not in the same file as the original function}}
555-
@derivative(of: +)
556-
static func vjpPlus(x: Self, y: Self) -> (
557-
value: Self, pullback: (Self) -> (Self, Self)
557+
pullback: (Self.TangentVector) -> (Self.TangentVector)
558558
) {
559-
return (x + y, { v in (v, v) })
559+
fatalError()
560560
}
561561
}
562562

0 commit comments

Comments
 (0)