Skip to content

Commit 433821b

Browse files
author
marcrasi
authored
[AutoDiff upstream] add more validation tests (#31112)
1 parent 8035d39 commit 433821b

File tree

7 files changed

+909
-0
lines changed

7 files changed

+909
-0
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import DifferentiationUnittest
6+
7+
var DerivativeRegistrationTests = TestSuite("DerivativeRegistration")
8+
9+
@_semantics("autodiff.opaque")
10+
func unary(x: Tracked<Float>) -> Tracked<Float> {
11+
return x
12+
}
13+
@derivative(of: unary)
14+
func _vjpUnary(x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
15+
return (value: x, pullback: { v in v })
16+
}
17+
DerivativeRegistrationTests.testWithLeakChecking("UnaryFreeFunction") {
18+
expectEqual(1, gradient(at: 3.0, in: unary))
19+
}
20+
21+
@_semantics("autodiff.opaque")
22+
func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
23+
return x * y
24+
}
25+
@derivative(of: multiply)
26+
func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
27+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
28+
return (x * y, { v in (v * y, v * x) })
29+
}
30+
DerivativeRegistrationTests.testWithLeakChecking("BinaryFreeFunction") {
31+
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in multiply(x, y) }))
32+
}
33+
34+
struct Wrapper : Differentiable {
35+
var float: Tracked<Float>
36+
}
37+
38+
extension Wrapper {
39+
@_semantics("autodiff.opaque")
40+
init(_ x: Tracked<Float>, _ y: Tracked<Float>) {
41+
self.float = x * y
42+
}
43+
44+
@derivative(of: init(_:_:))
45+
static func _vjpInit(_ x: Tracked<Float>, _ y: Tracked<Float>)
46+
-> (value: Self, pullback: (TangentVector) -> (Tracked<Float>, Tracked<Float>)) {
47+
return (.init(x, y), { v in (v.float * y, v.float * x) })
48+
}
49+
}
50+
DerivativeRegistrationTests.testWithLeakChecking("Initializer") {
51+
let v = Wrapper.TangentVector(float: 1)
52+
let (𝛁x, 𝛁y) = pullback(at: 3, 4, in: { x, y in Wrapper(x, y) })(v)
53+
expectEqual(4, 𝛁x)
54+
expectEqual(3, 𝛁y)
55+
}
56+
57+
extension Wrapper {
58+
@_semantics("autodiff.opaque")
59+
static func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
60+
return x * y
61+
}
62+
63+
@derivative(of: multiply)
64+
static func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
65+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
66+
return (x * y, { v in (v * y, v * x) })
67+
}
68+
}
69+
DerivativeRegistrationTests.testWithLeakChecking("StaticMethod") {
70+
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in Wrapper.multiply(x, y) }))
71+
}
72+
73+
extension Wrapper {
74+
@_semantics("autodiff.opaque")
75+
func multiply(_ x: Tracked<Float>) -> Tracked<Float> {
76+
return float * x
77+
}
78+
79+
@derivative(of: multiply)
80+
func _vjpMultiply(_ x: Tracked<Float>)
81+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Wrapper.TangentVector, Tracked<Float>)) {
82+
return (float * x, { v in
83+
(TangentVector(float: v * x), v * self.float)
84+
})
85+
}
86+
}
87+
DerivativeRegistrationTests.testWithLeakChecking("InstanceMethod") {
88+
let x: Tracked<Float> = 2
89+
let wrapper = Wrapper(float: 3)
90+
let (𝛁wrapper, 𝛁x) = gradient(at: wrapper, x) { wrapper, x in wrapper.multiply(x) }
91+
expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper)
92+
expectEqual(3, 𝛁x)
93+
}
94+
95+
extension Wrapper {
96+
subscript(_ x: Tracked<Float>) -> Tracked<Float> {
97+
@_semantics("autodiff.opaque")
98+
get { float * x }
99+
set {}
100+
}
101+
102+
@derivative(of: subscript(_:))
103+
func _vjpSubscript(_ x: Tracked<Float>)
104+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Wrapper.TangentVector, Tracked<Float>)) {
105+
return (self[x], { v in
106+
(TangentVector(float: v * x), v * self.float)
107+
})
108+
}
109+
}
110+
DerivativeRegistrationTests.testWithLeakChecking("Subscript") {
111+
let x: Tracked<Float> = 2
112+
let wrapper = Wrapper(float: 3)
113+
let (𝛁wrapper, 𝛁x) = gradient(at: wrapper, x) { wrapper, x in wrapper[x] }
114+
expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper)
115+
expectEqual(3, 𝛁x)
116+
}
117+
118+
extension Wrapper {
119+
var computedProperty: Tracked<Float> {
120+
@_semantics("autodiff.opaque")
121+
get { float * float }
122+
set {}
123+
}
124+
125+
@derivative(of: computedProperty)
126+
func _vjpComputedProperty()
127+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Wrapper.TangentVector) {
128+
return (computedProperty, { [f = self.float] v in
129+
TangentVector(float: v * (f + f))
130+
})
131+
}
132+
}
133+
DerivativeRegistrationTests.testWithLeakChecking("ComputedProperty") {
134+
let wrapper = Wrapper(float: 3)
135+
let 𝛁wrapper = gradient(at: wrapper) { wrapper in wrapper.computedProperty }
136+
expectEqual(Wrapper.TangentVector(float: 6), 𝛁wrapper)
137+
}
138+
139+
struct Generic<T> {
140+
@differentiable // derivative generic signature: none
141+
func instanceMethod(_ x: Tracked<Float>) -> Tracked<Float> {
142+
x
143+
}
144+
}
145+
extension Generic {
146+
@derivative(of: instanceMethod) // derivative generic signature: <T>
147+
func vjpInstanceMethod(_ x: Tracked<Float>)
148+
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
149+
(x, { v in 1000 })
150+
}
151+
}
152+
DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") {
153+
let generic = Generic<Float>()
154+
let x: Tracked<Float> = 3
155+
let dx = gradient(at: x) { x in generic.instanceMethod(x) }
156+
expectEqual(1000, dx)
157+
}
158+
159+
// When non-canonicalized generic signatures are used to compare derivative configurations, the
160+
// `@differentiable` and `@derivative` attributes create separate derivatives, and we get a
161+
// duplicate symbol error in TBDGen.
162+
public protocol RefinesDifferentiable: Differentiable {}
163+
extension Float: RefinesDifferentiable {}
164+
@differentiable(where T: Differentiable, T: RefinesDifferentiable)
165+
public func nonCanonicalizedGenSigComparison<T>(_ t: T) -> T { t }
166+
@derivative(of: nonCanonicalizedGenSigComparison)
167+
public func dNonCanonicalizedGenSigComparison<T: RefinesDifferentiable>(_ t: T)
168+
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
169+
{
170+
(t, { _ in T.TangentVector.zero })
171+
}
172+
DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatureComparison") {
173+
let dx = gradient(at: Float(0), in: nonCanonicalizedGenSigComparison)
174+
// Expect that we use the custom registered derivative, not a generated derivative (which would
175+
// give a gradient of 1).
176+
expectEqual(0, dx)
177+
}
178+
179+
// Test derivatives of default implementations.
180+
protocol HasADefaultImplementation {
181+
func req(_ x: Tracked<Float>) -> Tracked<Float>
182+
}
183+
extension HasADefaultImplementation {
184+
func req(_ x: Tracked<Float>) -> Tracked<Float> { x }
185+
@derivative(of: req)
186+
func req(_ x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
187+
(x, { 10 * $0 })
188+
}
189+
}
190+
struct StructConformingToHasADefaultImplementation : HasADefaultImplementation {}
191+
DerivativeRegistrationTests.testWithLeakChecking("DerivativeOfDefaultImplementation") {
192+
let dx = gradient(at: Tracked<Float>(0)) { StructConformingToHasADefaultImplementation().req($0) }
193+
expectEqual(Tracked<Float>(10), dx)
194+
}
195+
196+
runAllTests()
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
// An end-to-end test that we can differentiate property accesses, with custom
5+
// VJPs for the properties specified in various ways.
6+
7+
import StdlibUnittest
8+
import DifferentiationUnittest
9+
10+
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
11+
12+
struct TangentSpace : AdditiveArithmetic {
13+
let x, y: Tracked<Float>
14+
}
15+
16+
extension TangentSpace : Differentiable {
17+
typealias TangentVector = TangentSpace
18+
}
19+
20+
struct Space {
21+
/// `x` is a computed property with a custom vjp.
22+
var x: Tracked<Float> {
23+
@differentiable
24+
get { storedX }
25+
set { storedX = newValue }
26+
}
27+
28+
@derivative(of: x)
29+
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> TangentSpace) {
30+
return (x, { v in TangentSpace(x: v, y: 0) } )
31+
}
32+
33+
private var storedX: Tracked<Float>
34+
35+
@differentiable
36+
var y: Tracked<Float>
37+
38+
init(x: Tracked<Float>, y: Tracked<Float>) {
39+
self.storedX = x
40+
self.y = y
41+
}
42+
}
43+
44+
extension Space : Differentiable {
45+
typealias TangentVector = TangentSpace
46+
mutating func move(along direction: TangentSpace) {
47+
x.move(along: direction.x)
48+
y.move(along: direction.y)
49+
}
50+
}
51+
52+
E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
53+
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Tracked<Float> in
54+
return 2 * point.x
55+
}
56+
let expectedGrad = TangentSpace(x: 2, y: 0)
57+
expectEqual(expectedGrad, actualGrad)
58+
}
59+
60+
E2EDifferentiablePropertyTests.testWithLeakChecking("stored property") {
61+
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Tracked<Float> in
62+
return 3 * point.y
63+
}
64+
let expectedGrad = TangentSpace(x: 0, y: 3)
65+
expectEqual(expectedGrad, actualGrad)
66+
}
67+
68+
struct GenericMemberWrapper<T : Differentiable> : Differentiable {
69+
// Stored property.
70+
@differentiable
71+
var x: T
72+
73+
func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
74+
return (x, { TangentVector(x: $0) })
75+
}
76+
}
77+
78+
E2EDifferentiablePropertyTests.testWithLeakChecking("generic stored property") {
79+
let actualGrad = gradient(at: GenericMemberWrapper<Tracked<Float>>(x: 1)) { point in
80+
return 2 * point.x
81+
}
82+
let expectedGrad = GenericMemberWrapper<Tracked<Float>>.TangentVector(x: 2)
83+
expectEqual(expectedGrad, actualGrad)
84+
}
85+
86+
struct ProductSpaceSelfTangent : AdditiveArithmetic {
87+
let x, y: Tracked<Float>
88+
}
89+
90+
extension ProductSpaceSelfTangent : Differentiable {
91+
typealias TangentVector = ProductSpaceSelfTangent
92+
}
93+
94+
E2EDifferentiablePropertyTests.testWithLeakChecking("fieldwise product space, self tangent") {
95+
let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Tracked<Float> in
96+
return 5 * point.y
97+
}
98+
let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
99+
expectEqual(expectedGrad, actualGrad)
100+
}
101+
102+
struct ProductSpaceOtherTangentTangentSpace : AdditiveArithmetic {
103+
let x, y: Tracked<Float>
104+
}
105+
106+
extension ProductSpaceOtherTangentTangentSpace : Differentiable {
107+
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
108+
}
109+
110+
struct ProductSpaceOtherTangent {
111+
var x, y: Tracked<Float>
112+
}
113+
114+
extension ProductSpaceOtherTangent : Differentiable {
115+
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
116+
mutating func move(along direction: ProductSpaceOtherTangentTangentSpace) {
117+
x.move(along: direction.x)
118+
y.move(along: direction.y)
119+
}
120+
}
121+
122+
E2EDifferentiablePropertyTests.testWithLeakChecking("fieldwise product space, other tangent") {
123+
let actualGrad = gradient(
124+
at: ProductSpaceOtherTangent(x: 0, y: 0)
125+
) { (point: ProductSpaceOtherTangent) -> Tracked<Float> in
126+
return 7 * point.y
127+
}
128+
let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
129+
expectEqual(expectedGrad, actualGrad)
130+
}
131+
132+
E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
133+
struct TF_544 : Differentiable {
134+
var value: Tracked<Float>
135+
@differentiable
136+
var computed: Tracked<Float> {
137+
get { value }
138+
set { value = newValue }
139+
}
140+
}
141+
let actualGrad = gradient(at: TF_544(value: 2.4)) { x in
142+
return x.computed * x.computed
143+
}
144+
let expectedGrad = TF_544.TangentVector(value: 4.8)
145+
expectEqual(expectedGrad, actualGrad)
146+
}
147+
148+
runAllTests()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import DifferentiationUnittest
6+
7+
var ExistentialTests = TestSuite("Existential")
8+
9+
protocol A {
10+
@differentiable(wrt: x)
11+
func a(_ x: Tracked<Float>) -> Tracked<Float>
12+
}
13+
func b(g: A) -> Tracked<Float> {
14+
return gradient(at: 3) { x in g.a(x) }
15+
}
16+
17+
struct B : A {
18+
@differentiable(wrt: x)
19+
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
20+
}
21+
22+
ExistentialTests.testWithLeakChecking("Existential method VJP") {
23+
expectEqual(5.0, b(g: B()))
24+
}
25+
26+
runAllTests()

0 commit comments

Comments
 (0)