Skip to content

Commit 9a55217

Browse files
committed
Add test.
1 parent 4f7f9bf commit 9a55217

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: %target-run-simple-swift
2+
// NOTE: Verify whether forward-mode differentiation crashes. It currently does.
3+
// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
4+
// REQUIRES: executable_test
5+
6+
import StdlibUnittest
7+
import DifferentiationUnittest
8+
9+
var ClassTests = TestSuite("ClassDifferentiation")
10+
11+
ClassTests.test("TrivialMember") {
12+
class C: Differentiable {
13+
@differentiable
14+
var float: Float
15+
16+
@noDerivative
17+
final var noDerivative: Float = 1
18+
19+
init(_ float: Float) {
20+
self.float = float
21+
}
22+
23+
@differentiable
24+
func method(_ x: Float) -> Float {
25+
x * float
26+
}
27+
28+
@differentiable
29+
func testNoDerivative() -> Float {
30+
noDerivative
31+
}
32+
33+
@differentiable
34+
static func controlFlow(_ c1: C, _ c2: C, _ flag: Bool) -> Float {
35+
var result: Float = 0
36+
if flag {
37+
result = c1.float * c2.float
38+
} else {
39+
result = c2.float * c1.float
40+
}
41+
return result
42+
}
43+
}
44+
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
45+
expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() }))
46+
expectEqual((.init(float: 20), .init(float: 10)),
47+
gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
48+
}
49+
50+
ClassTests.test("NontrivialMember") {
51+
class C: Differentiable {
52+
@differentiable
53+
var float: Tracked<Float>
54+
55+
init(_ float: Tracked<Float>) {
56+
self.float = float
57+
}
58+
59+
@differentiable
60+
func method(_ x: Tracked<Float>) -> Tracked<Float> {
61+
x * float
62+
}
63+
64+
@differentiable
65+
static func controlFlow(_ c1: C, _ c2: C, _ flag: Bool) -> Tracked<Float> {
66+
var result: Tracked<Float> = 0
67+
if flag {
68+
result = c1.float * c2.float
69+
} else {
70+
result = c2.float * c1.float
71+
}
72+
return result
73+
}
74+
}
75+
expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) }))
76+
expectEqual((.init(float: 20), .init(float: 10)),
77+
gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) }))
78+
}
79+
80+
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
81+
// TODO(TF-1149): Uncomment when supported.
82+
/*
83+
ClassTests.test("AddressOnlyTangentVector") {
84+
class C<T: Differentiable>: Differentiable {
85+
@differentiable
86+
var stored: T
87+
88+
init(_ stored: T) {
89+
self.stored = stored
90+
}
91+
92+
@differentiable
93+
func method(_ x: T) -> T {
94+
stored
95+
}
96+
}
97+
expectEqual((.init(stored: Float(3)), 10),
98+
gradient(at: C<Float>(3), 3, in: { c, x in c.method(x) }))
99+
}
100+
*/
101+
102+
runAllTests()

0 commit comments

Comments
 (0)