Skip to content

Commit b6e0ffd

Browse files
authored
Merge pull request #29031 from marcrasi/forbid-protocol-req
2 parents 19f0d52 + 29465f8 commit b6e0ffd

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3498,6 +3498,9 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
34983498
};
34993499

35003500
auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) {
3501+
// TODO(TF-982): Allow derivatives on protocol requirements.
3502+
if (isa<ProtocolDecl>(originalCandidate->getDeclContext()))
3503+
return false;
35013504
return checkFunctionSignature(
35023505
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
35033506
originalCandidate->getInterfaceType()->getCanonicalType(),

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185185
static func generic<T: Differentiable>(_ x: T) -> T
186186
}
187187

188+
extension StaticMethod {
189+
static func foo(_ x: Float) -> Float { x }
190+
static func generic<T: Differentiable>(_ x: T) -> T { x }
191+
}
192+
188193
extension StaticMethod {
189194
@derivative(of: foo)
190195
static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
@@ -215,11 +220,16 @@ extension StaticMethod {
215220
// Test instance methods.
216221

217222
protocol InstanceMethod: Differentiable {
218-
// expected-note @+1 {{'foo' defined here}}
219223
func foo(_ x: Self) -> Self
224+
func generic<T: Differentiable>(_ x: T) -> Self
225+
}
226+
227+
extension InstanceMethod {
228+
// expected-note @+1 {{'foo' defined here}}
229+
func foo(_ x: Self) -> Self { x }
220230

221231
// expected-note @+1 {{'generic' defined here}}
222-
func generic<T: Differentiable>(_ x: T) -> Self
232+
func generic<T: Differentiable>(_ x: T) -> Self { self }
223233
}
224234

225235
extension InstanceMethod {
@@ -536,27 +546,34 @@ extension HasStoredProperty {
536546
}
537547
}
538548

539-
// Test cross-file derivative registration. Currently unsupported.
540-
// TODO(TF-1021): Lift this restriction.
549+
// Test derivative registration for protocol requirements. Currently unsupported.
550+
// TODO(TF-982): Lift this restriction and add proper support.
541551

542-
extension AdditiveArithmetic where Self: Differentiable {
543-
// expected-error @+1 {{derivative not in the same file as the original function}}
544-
@derivative(of: +)
545-
static func vjpPlus(x: Self, y: Self) -> (
546-
value: Self,
547-
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
548-
) {
549-
return (x + y, { v in (v, v) })
552+
protocol ProtocolRequirementDerivative {
553+
func requirement(_ x: Float) -> Float
554+
}
555+
extension ProtocolRequirementDerivative {
556+
// NOTE: the error is misleading because `findAbstractFunctionDecl` in
557+
// TypeCheckAttr.cpp is not setup to show customized error messages for
558+
// invalid original function candidates.
559+
// expected-error @+1 {{could not find function 'requirement' with expected type '<Self where Self : ProtocolRequirementDerivative> (Self) -> (Float) -> Float'}}
560+
@derivative(of: requirement)
561+
func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
562+
fatalError()
550563
}
551564
}
552565

553-
extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector {
566+
// Test cross-file derivative registration. Currently unsupported.
567+
// TODO(TF-1021): Lift this restriction.
568+
569+
extension FloatingPoint where Self: Differentiable {
554570
// expected-error @+1 {{derivative not in the same file as the original function}}
555-
@derivative(of: +)
556-
static func vjpPlus(x: Self, y: Self) -> (
557-
value: Self, pullback: (Self) -> (Self, Self)
571+
@derivative(of: rounded)
572+
func vjpRounded() -> (
573+
value: Self,
574+
pullback: (Self.TangentVector) -> (Self.TangentVector)
558575
) {
559-
return (x + y, { v in (v, v) })
576+
fatalError()
560577
}
561578
}
562579

0 commit comments

Comments
 (0)