Skip to content

Commit 54b8508

Browse files
marcrasidan-zheng
authored andcommitted
[AutoDiff] forbid @Derivative of protocol req (#28890)
1 parent 5db80ee commit 54b8508

File tree

6 files changed

+97
-44
lines changed

6 files changed

+97
-44
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,10 +3516,12 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
35163516
};
35173517

35183518
auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) {
3519-
return checkFunctionSignature(
3520-
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
3521-
originalCandidate->getInterfaceType()->getCanonicalType(),
3522-
checkGenericSignatureSatisfied);
3519+
// TODO(TF-982): Allow derivatives on protocol requirements.
3520+
return !isa<ProtocolDecl>(originalCandidate->getDeclContext()) &&
3521+
checkFunctionSignature(
3522+
cast<AnyFunctionType>(originalFnType->getCanonicalType()),
3523+
originalCandidate->getInterfaceType()->getCanonicalType(),
3524+
checkGenericSignatureSatisfied);
35233525
};
35243526

35253527
auto noneValidDiagnostic = [&]() {

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable {
185185
static func generic<T: Differentiable>(_ x: T) -> T
186186
}
187187

188+
extension StaticMethod {
189+
static func foo(_ x: Float) -> Float { x }
190+
static func generic<T: Differentiable>(_ x: T) -> T { x }
191+
}
192+
188193
extension StaticMethod {
189194
@derivative(of: foo)
190195
static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float)
@@ -215,13 +220,19 @@ extension StaticMethod {
215220
// Test instance methods.
216221

217222
protocol InstanceMethod: Differentiable {
218-
// expected-note @+1 {{'foo' defined here}}
219223
func foo(_ x: Self) -> Self
220224

221-
// expected-note @+1 {{'generic' defined here}}
222225
func generic<T: Differentiable>(_ x: T) -> Self
223226
}
224227

228+
extension InstanceMethod {
229+
// expected-note @+1 {{'foo' defined here}}
230+
func foo(_ x: Self) -> Self { self }
231+
232+
// expected-note @+1 {{'generic' defined here}}
233+
func generic<T: Differentiable>(_ x: T) -> Self { self }
234+
}
235+
225236
extension InstanceMethod {
226237
@derivative(of: foo)
227238
func jvpFoo(x: Self) -> (
@@ -538,27 +549,15 @@ extension HasStoredProperty {
538549

539550
// Test cross-file derivative registration. Currently unsupported.
540551
// TODO(TF-1021): Lift this restriction.
541-
542-
extension AdditiveArithmetic where Self: Differentiable {
552+
extension FloatingPoint where Self: Differentiable {
543553
// expected-error @+1 {{derivative not in the same file as the original function}}
544-
@derivative(of: +)
545-
static func vjpPlus(x: Self, y: Self) -> (
546-
value: Self,
547-
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)
548-
) {
549-
return (x + y, { v in (v, v) })
554+
@derivative(of: rounded)
555+
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
556+
fatalError()
550557
}
551558
}
552559

553-
extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector {
554-
// expected-error @+1 {{derivative not in the same file as the original function}}
555-
@derivative(of: +)
556-
static func vjpPlus(x: Self, y: Self) -> (
557-
value: Self, pullback: (Self) -> (Self, Self)
558-
) {
559-
return (x + y, { v in (v, v) })
560-
}
561-
}
560+
// Test static methods.
562561

563562
extension Differentiable where Self: AdditiveArithmetic {
564563
// expected-error @+1 {{'+' is not defined in the current type context}}
@@ -581,3 +580,29 @@ where Self: Differentiable, Self == Self.TangentVector {
581580
return (x + y, { v in (v, v) })
582581
}
583582
}
583+
584+
// Test derivatives of default implementations.
585+
protocol HasADefaultImplementation {
586+
func req(_ x: Float) -> Float
587+
}
588+
extension HasADefaultImplementation {
589+
func req(_ x: Float) -> Float { x }
590+
// ok
591+
@derivative(of: req)
592+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
593+
(x, { 10 * $0 })
594+
}
595+
}
596+
597+
// Test default derivatives of requirements.
598+
protocol HasADefaultDerivative {
599+
func req(_ x: Float) -> Float
600+
}
601+
extension HasADefaultDerivative {
602+
// TODO(TF-982): Make this ok.
603+
// expected-error @+1 {{could not find function 'req'}}
604+
@derivative(of: req)
605+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
606+
(x, { 10 * $0 })
607+
}
608+
}

test/AutoDiff/compiler_crashers_fixed/tf1039-cloned-curry-thunk-verification.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@ protocol P {
99
@differentiable
1010
func foo(_ x: Float) -> Float
1111
}
12-
extension P {
13-
@derivative(of: foo)
14-
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
15-
return (x, { $0 })
16-
}
17-
}
1812
struct S: P {
1913
@differentiable
2014
func foo(_ x: Float) -> Float { x }

test/AutoDiff/derivative_attr_type_checking.swift

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,17 @@ func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryI
151151
return (x, { $0 })
152152
}
153153

154-
// Test static methods.
155-
156-
extension AdditiveArithmetic where Self : Differentiable {
154+
// Test cross-file derivative registration. Currently unsupported.
155+
// TODO(TF-1021): Lift this restriction.
156+
extension FloatingPoint where Self: Differentiable {
157157
// expected-error @+1 {{derivative not in the same file as the original function}}
158-
@derivative(of: +)
159-
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
160-
return (x + y, { v in (v, v) })
158+
@derivative(of: rounded)
159+
func vjpRounded() -> (value: Self, pullback: (TangentVector) -> TangentVector) {
160+
fatalError()
161161
}
162162
}
163163

164-
extension FloatingPoint where Self : Differentiable, Self == Self.TangentVector {
165-
// expected-error @+1 {{derivative not in the same file as the original function}}
166-
@derivative(of: +)
167-
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
168-
return (x + y, { v in (v, v) })
169-
}
170-
}
164+
// Test static methods.
171165

172166
extension Differentiable where Self : AdditiveArithmetic {
173167
// expected-error @+1 {{'+' is not defined in the current type context}}
@@ -188,14 +182,21 @@ extension AdditiveArithmetic where Self : Differentiable, Self == Self.TangentVe
188182
// Test instance methods.
189183

190184
protocol InstanceMethod : Differentiable {
191-
// expected-note @+1 {{'foo' defined here}}
192185
func foo(_ x: Self) -> Self
193186
func foo2(_ x: Self) -> Self
194-
// expected-note @+1 {{'bar' defined here}}
195187
func bar<T : Differentiable>(_ x: T) -> Self
196188
func bar2<T : Differentiable>(_ x: T) -> Self
197189
}
198190

191+
extension InstanceMethod {
192+
// expected-note @+1 {{'foo' defined here}}
193+
func foo(_ x: Self) -> Self { self }
194+
func foo2(_ x: Self) -> Self { self }
195+
// expected-note @+1 {{'bar' defined here}}
196+
func bar<T : Differentiable>(_ x: T) -> Self { self }
197+
func bar2<T : Differentiable>(_ x: T) -> Self { self}
198+
}
199+
199200
extension InstanceMethod {
200201
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
201202
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
@@ -261,6 +262,10 @@ protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector
261262
func instanceMethod<T : Differentiable>(_ x: T) -> T
262263
}
263264

265+
extension GenericInstanceMethod {
266+
func instanceMethod<T : Differentiable>(_ x: T) -> T { x }
267+
}
268+
264269
extension GenericInstanceMethod {
265270
@derivative(of: instanceMethod)
266271
func jvpInstanceMethod<T : Differentiable>(_ x: T) -> (value: T, differential: (Self, T.TangentVector) -> (T.TangentVector)) {
@@ -297,6 +302,9 @@ func vjpBaz<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
297302
protocol InstanceMethodProto {
298303
func bar() -> Float
299304
}
305+
extension InstanceMethodProto {
306+
func bar() -> Float { 0 }
307+
}
300308
extension InstanceMethodProto where Self : Differentiable {
301309
@derivative(of: bar)
302310
func vjpBar() -> (value: Float, pullback: (Float) -> TangentVector) {
@@ -318,6 +326,9 @@ protocol Protocol: Differentiable {
318326
func requirementOverlapping() -> Self
319327
}
320328
extension Protocol {
329+
@differentiable
330+
func requirementOverlapping() -> Self { self }
331+
321332
func nonRequirementOnlyDerivativeAttr() -> Self { self }
322333

323334
@differentiable

test/AutoDiff/derivative_registration.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,21 @@ DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatu
176176
expectEqual(0, dx)
177177
}
178178

179+
// Test derivatives of default implementations.
180+
protocol HasADefaultImplementation {
181+
func req(_ x: Tracked<Float>) -> Tracked<Float>
182+
}
183+
extension HasADefaultImplementation {
184+
func req(_ x: Tracked<Float>) -> Tracked<Float> { x }
185+
@derivative(of: req)
186+
func req(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
187+
(x, { 10 * $0 })
188+
}
189+
}
190+
struct StructConformingToHasADefaultImplementation : HasADefaultImplementation {}
191+
DerivativeRegistrationTests.testWithLeakChecking("DerivativeOfDefaultImplementation") {
192+
let dx = gradient(at: Tracked<Float>(0)) { StructConformingToHasADefaultImplementation().req($0) }
193+
expectEqual(Tracked<Float>(10), dx)
194+
}
195+
179196
runAllTests()

test/Serialization/derivative_attr.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ protocol InstanceMethod : Differentiable {
3636
func foo(_ x: Self) -> Self
3737
func bar<T : Differentiable>(_ x: T) -> Self
3838
}
39+
extension InstanceMethod {
40+
func foo(_ x: Self) -> Self { self }
41+
func bar<T : Differentiable>(_ x: T) -> Self { self }
42+
}
3943
extension InstanceMethod {
4044
// CHECK: @derivative(of: foo, wrt: (self, x))
4145
@derivative(of: foo)

0 commit comments

Comments
 (0)