Skip to content

[AutoDiff] forbid @derivative of protocol req #28890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3473,10 +3473,12 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
};

auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) {
return checkFunctionSignature(
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
originalCandidate->getInterfaceType()->getCanonicalType(),
checkGenericSignatureSatisfied);
// TODO(TF-982): Allow derivatives on protocol requirements.
return !isa<ProtocolDecl>(originalCandidate->getDeclContext()) &&
checkFunctionSignature(
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
originalCandidate->getInterfaceType()->getCanonicalType(),
checkGenericSignatureSatisfied);
};

auto noneValidDiagnostic = [&]() {
Expand Down
63 changes: 44 additions & 19 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
static func generic<T: Differentiable>(_ x: T) -> T
}

extension StaticMethod {
static func foo(_ x: Float) -> Float { x }
static func generic<T: Differentiable>(_ x: T) -> T { x }
}

extension StaticMethod {
@derivative(of: foo)
static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
Expand Down Expand Up @@ -214,13 +219,19 @@ extension StaticMethod {
// Test instance methods.

protocol InstanceMethod: Differentiable {
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self

// expected-note @+1 {{'generic' defined here}}
func generic<T: Differentiable>(_ x: T) -> Self
}

extension InstanceMethod {
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self { self }

// expected-note @+1 {{'generic' defined here}}
func generic<T: Differentiable>(_ x: T) -> Self { self }
}

extension InstanceMethod {
@derivative(of: foo)
func jvpFoo(x: Self) -> (
Expand Down Expand Up @@ -499,27 +510,15 @@ extension HasStoredProperty {

// Test cross-file derivative registration. Currently unsupported.
// TODO(TF-1021): Lift this restriction.

extension AdditiveArithmetic where Self: Differentiable {
extension FloatingPoint where Self: Differentiable {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: +)
static func vjpPlus(x: Self, y: Self) -> (
value: Self,
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
) {
return (x + y, { v in (v, v) })
@derivative(of: rounded)
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
fatalError()
}
}

extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: +)
static func vjpPlus(x: Self, y: Self) -> (
value: Self, pullback: (Self) -> (Self, Self)
) {
return (x + y, { v in (v, v) })
}
}
// Test static methods.

extension Differentiable where Self: AdditiveArithmetic {
// expected-error @+1 {{'+' is not defined in the current type context}}
Expand All @@ -542,3 +541,29 @@ where Self: Differentiable, Self == Self.TangentVector {
return (x + y, { v in (v, v) })
}
}

// Test derivatives of default implementations.
protocol HasADefaultImplementation {
func req(_ x: Float) -> Float
}
extension HasADefaultImplementation {
func req(_ x: Float) -> Float { x }
// ok
@derivative(of: req)
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 10 * $0 })
}
}

// Test default derivatives of requirements.
protocol HasADefaultDerivative {
func req(_ x: Float) -> Float
}
extension HasADefaultDerivative {
// TODO(TF-982): Make this ok.
// expected-error @+1 {{could not find function 'req'}}
@derivative(of: req)
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 10 * $0 })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ protocol P {
@differentiable
func foo(_ x: Float) -> Float
}
extension P {
@derivative(of: foo)
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
}
struct S: P {
@differentiable
func foo(_ x: Float) -> Float { x }
Expand Down
41 changes: 26 additions & 15 deletions test/AutoDiff/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -151,23 +151,17 @@ func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryI
return (x, { $0 })
}

// Test static methods.

extension AdditiveArithmetic where Self : Differentiable {
// Test cross-file derivative registration. Currently unsupported.
// TODO(TF-1021): Lift this restriction.
extension FloatingPoint where Self: Differentiable {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: +)
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
return (x + y, { v in (v, v) })
@derivative(of: rounded)
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
fatalError()
}
}

extension FloatingPoint where Self : Differentiable, Self == Self.TangentVector {
// expected-error @+1 {{derivative not in the same file as the original function}}
@derivative(of: +)
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
return (x + y, { v in (v, v) })
}
}
// Test static methods.

extension Differentiable where Self : AdditiveArithmetic {
// expected-error @+1 {{'+' is not defined in the current type context}}
Expand All @@ -188,14 +182,21 @@ extension AdditiveArithmetic where Self : Differentiable, Self == Self.TangentVe
// Test instance methods.

protocol InstanceMethod : Differentiable {
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self
func foo2(_ x: Self) -> Self
// expected-note @+1 {{'bar' defined here}}
func bar<T : Differentiable>(_ x: T) -> Self
func bar2<T : Differentiable>(_ x: T) -> Self
}

extension InstanceMethod {
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self { self }
func foo2(_ x: Self) -> Self { self }
// expected-note @+1 {{'bar' defined here}}
func bar<T : Differentiable>(_ x: T) -> Self { self }
func bar2<T : Differentiable>(_ x: T) -> Self { self}
}

extension InstanceMethod {
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
Expand Down Expand Up @@ -261,6 +262,10 @@ protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector
func instanceMethod<T : Differentiable>(_ x: T) -> T
}

extension GenericInstanceMethod {
func instanceMethod<T : Differentiable>(_ x: T) -> T { x }
}

extension GenericInstanceMethod {
@derivative(of: instanceMethod)
func jvpInstanceMethod<T : Differentiable>(_ x: T) -> (value: T, differential: (Self, T.TangentVector) -> (T.TangentVector)) {
Expand Down Expand Up @@ -297,6 +302,9 @@ func vjpBaz<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
protocol InstanceMethodProto {
func bar() -> Float
}
extension InstanceMethodProto {
func bar() -> Float { 0 }
}
extension InstanceMethodProto where Self : Differentiable {
@derivative(of: bar)
func vjpBar() -> (value: Float, pullback: (Float) -> TangentVector) {
Expand All @@ -318,6 +326,9 @@ protocol Protocol: Differentiable {
func requirementOverlapping() -> Self
}
extension Protocol {
@differentiable
func requirementOverlapping() -> Self { self }

func nonRequirementOnlyDerivativeAttr() -> Self { self }

@differentiable
Expand Down
17 changes: 17 additions & 0 deletions test/AutoDiff/derivative_registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,21 @@ DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatu
expectEqual(0, dx)
}

// Test derivatives of default implementations.
protocol HasADefaultImplementation {
func req(_ x: Tracked<Float>) -> Tracked<Float>
}
extension HasADefaultImplementation {
func req(_ x: Tracked<Float>) -> Tracked<Float> { x }
@derivative(of: req)
func req(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
(x, { 10 * $0 })
}
}
struct StructConformingToHasADefaultImplementation : HasADefaultImplementation {}
DerivativeRegistrationTests.testWithLeakChecking("DerivativeOfDefaultImplementation") {
let dx = gradient(at: Tracked<Float>(0)) { StructConformingToHasADefaultImplementation().req($0) }
expectEqual(Tracked<Float>(10), dx)
}

runAllTests()
4 changes: 4 additions & 0 deletions test/Serialization/derivative_attr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ protocol InstanceMethod : Differentiable {
func foo(_ x: Self) -> Self
func bar<T : Differentiable>(_ x: T) -> Self
}
extension InstanceMethod {
func foo(_ x: Self) -> Self { self }
func bar<T : Differentiable>(_ x: T) -> Self { self }
}
extension InstanceMethod {
// CHECK: @derivative(of: foo, wrt: (self, x))
@derivative(of: foo)
Expand Down