Skip to content

Commit d5aed23

Browse files
author
Marc Rasi
committed
[AutoDiff] forbid @Derivative of protocol req
1 parent 7d3ae09 commit d5aed23

File tree

5 files changed

+93
-44
lines changed

5 files changed

+93
-44
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3466,10 +3466,12 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
34663466
};
34673467

34683468
auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) {
3469-
return checkFunctionSignature(
3470-
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
3471-
originalCandidate->getInterfaceType()->getCanonicalType(),
3472-
checkGenericSignatureSatisfied);
3469+
// TODO(TF-982): Allow derivatives on protocol requirements.
3470+
return !isa<ProtocolDecl>(originalCandidate->getDeclContext()) &&
3471+
checkFunctionSignature(
3472+
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
3473+
originalCandidate->getInterfaceType()->getCanonicalType(),
3474+
checkGenericSignatureSatisfied);
34733475
};
34743476

34753477
auto noneValidDiagnostic = [&]() {

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185185
static func generic<T: Differentiable>(_ x: T) -> T
186186
}
187187

188+
extension StaticMethod {
189+
static func foo(_ x: Float) -> Float { x }
190+
static func generic<T: Differentiable>(_ x: T) -> T { x }
191+
}
192+
188193
extension StaticMethod {
189194
@derivative(of: foo)
190195
static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
@@ -214,13 +219,19 @@ extension StaticMethod {
214219
// Test instance methods.
215220

216221
protocol InstanceMethod: Differentiable {
217-
// expected-note @+1 {{'foo' defined here}}
218222
func foo(_ x: Self) -> Self
219223

220-
// expected-note @+1 {{'generic' defined here}}
221224
func generic<T: Differentiable>(_ x: T) -> Self
222225
}
223226

227+
extension InstanceMethod {
228+
// expected-note @+1 {{'foo' defined here}}
229+
func foo(_ x: Self) -> Self { self }
230+
231+
// expected-note @+1 {{'generic' defined here}}
232+
func generic<T: Differentiable>(_ x: T) -> Self { self }
233+
}
234+
224235
extension InstanceMethod {
225236
@derivative(of: foo)
226237
func jvpFoo(x: Self) -> (
@@ -499,27 +510,15 @@ extension HasStoredProperty {
499510

500511
// Test cross-file derivative registration. Currently unsupported.
501512
// TODO(TF-1021): Lift this restriction.
502-
503-
extension AdditiveArithmetic where Self: Differentiable {
513+
extension FloatingPoint where Self: Differentiable {
504514
// expected-error @+1 {{derivative not in the same file as the original function}}
505-
@derivative(of: +)
506-
static func vjpPlus(x: Self, y: Self) -> (
507-
value: Self,
508-
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
509-
) {
510-
return (x + y, { v in (v, v) })
515+
@derivative(of: rounded)
516+
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
517+
fatalError()
511518
}
512519
}
513520

514-
extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector {
515-
// expected-error @+1 {{derivative not in the same file as the original function}}
516-
@derivative(of: +)
517-
static func vjpPlus(x: Self, y: Self) -> (
518-
value: Self, pullback: (Self) -> (Self, Self)
519-
) {
520-
return (x + y, { v in (v, v) })
521-
}
522-
}
521+
// Test static methods.
523522

524523
extension Differentiable where Self: AdditiveArithmetic {
525524
// expected-error @+1 {{'+' is not defined in the current type context}}
@@ -542,3 +541,29 @@ where Self: Differentiable, Self == Self.TangentVector {
542541
return (x + y, { v in (v, v) })
543542
}
544543
}
544+
545+
// Test derivatives of default implementations.
546+
protocol HasADefaultImplementation {
547+
func req(_ x: Float) -> Float
548+
}
549+
extension HasADefaultImplementation {
550+
func req(_ x: Float) -> Float { x }
551+
// ok
552+
@derivative(of: req)
553+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
554+
(x, { 10 * $0 })
555+
}
556+
}
557+
558+
// Test default derivatives of requirements.
559+
protocol HasADefaultDerivative {
560+
func req(_ x: Float) -> Float
561+
}
562+
extension HasADefaultDerivative {
563+
// TODO(TF-982): Make this ok.
564+
// expected-error @+1 {{could not find function 'req'}}
565+
@derivative(of: req)
566+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
567+
(x, { 10 * $0 })
568+
}
569+
}

test/AutoDiff/compiler_crashers_fixed/tf1039-cloned-curry-thunk-verification.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@ protocol P {
99
@differentiable
1010
func foo(_ x: Float) -> Float
1111
}
12-
extension P {
13-
@derivative(of: foo)
14-
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
15-
return (x, { $0 })
16-
}
17-
}
1812
struct S: P {
1913
@differentiable
2014
func foo(_ x: Float) -> Float { x }

test/AutoDiff/derivative_attr_type_checking.swift

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,17 @@ func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryI
151151
return (x, { $0 })
152152
}
153153

154-
// Test static methods.
155-
156-
extension AdditiveArithmetic where Self : Differentiable {
154+
// Test cross-file derivative registration. Currently unsupported.
155+
// TODO(TF-1021): Lift this restriction.
156+
extension FloatingPoint where Self: Differentiable {
157157
// expected-error @+1 {{derivative not in the same file as the original function}}
158-
@derivative(of: +)
159-
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
160-
return (x + y, { v in (v, v) })
158+
@derivative(of: rounded)
159+
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
160+
fatalError()
161161
}
162162
}
163163

164-
extension FloatingPoint where Self : Differentiable, Self == Self.TangentVector {
165-
// expected-error @+1 {{derivative not in the same file as the original function}}
166-
@derivative(of: +)
167-
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
168-
return (x + y, { v in (v, v) })
169-
}
170-
}
164+
// Test static methods.
171165

172166
extension Differentiable where Self : AdditiveArithmetic {
173167
// expected-error @+1 {{'+' is not defined in the current type context}}
@@ -188,14 +182,21 @@ extension AdditiveArithmetic where Self : Differentiable, Self == Self.TangentVe
188182
// Test instance methods.
189183

190184
protocol InstanceMethod : Differentiable {
191-
// expected-note @+1 {{'foo' defined here}}
192185
func foo(_ x: Self) -> Self
193186
func foo2(_ x: Self) -> Self
194-
// expected-note @+1 {{'bar' defined here}}
195187
func bar<T : Differentiable>(_ x: T) -> Self
196188
func bar2<T : Differentiable>(_ x: T) -> Self
197189
}
198190

191+
extension InstanceMethod {
192+
// expected-note @+1 {{'foo' defined here}}
193+
func foo(_ x: Self) -> Self { self }
194+
func foo2(_ x: Self) -> Self { self }
195+
// expected-note @+1 {{'bar' defined here}}
196+
func bar<T : Differentiable>(_ x: T) -> Self { self }
197+
func bar2<T : Differentiable>(_ x: T) -> Self { self}
198+
}
199+
199200
extension InstanceMethod {
200201
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
201202
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
@@ -261,6 +262,10 @@ protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector
261262
func instanceMethod<T : Differentiable>(_ x: T) -> T
262263
}
263264

265+
extension GenericInstanceMethod {
266+
func instanceMethod<T : Differentiable>(_ x: T) -> T { x }
267+
}
268+
264269
extension GenericInstanceMethod {
265270
@derivative(of: instanceMethod)
266271
func jvpInstanceMethod<T : Differentiable>(_ x: T) -> (value: T, differential: (Self, T.TangentVector) -> (T.TangentVector)) {
@@ -297,6 +302,9 @@ func vjpBaz<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
297302
protocol InstanceMethodProto {
298303
func bar() -> Float
299304
}
305+
extension InstanceMethodProto {
306+
func bar() -> Float { 0 }
307+
}
300308
extension InstanceMethodProto where Self : Differentiable {
301309
@derivative(of: bar)
302310
func vjpBar() -> (value: Float, pullback: (Float) -> TangentVector) {
@@ -318,6 +326,9 @@ protocol Protocol: Differentiable {
318326
func requirementOverlapping() -> Self
319327
}
320328
extension Protocol {
329+
@differentiable
330+
func requirementOverlapping() -> Self { self }
331+
321332
func nonRequirementOnlyDerivativeAttr() -> Self { self }
322333

323334
@differentiable

test/AutoDiff/derivative_registration.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,21 @@ DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatu
176176
expectEqual(0, dx)
177177
}
178178

179+
// Test derivatives of default implementations.
180+
protocol HasADefaultImplementation {
181+
func req(_ x: Tracked<Float>) -> Tracked<Float>
182+
}
183+
extension HasADefaultImplementation {
184+
func req(_ x: Tracked<Float>) -> Tracked<Float> { x }
185+
@derivative(of: req)
186+
func req(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
187+
(x, { 10 * $0 })
188+
}
189+
}
190+
struct StructConformingToHasADefaultImplementation : HasADefaultImplementation {}
191+
DerivativeRegistrationTests.testWithLeakChecking("DerivativeOfDefaultImplementation") {
192+
let dx = gradient(at: Tracked<Float>(0)) { StructConformingToHasADefaultImplementation().req($0) }
193+
expectEqual(Tracked<Float>(10), dx)
194+
}
195+
179196
runAllTests()

0 commit comments

Comments
 (0)