Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 786e774

Browse files
jon-towmarcrasi
authored andcommitted
[Tests] Port tensor control flow AD tests (#413)
* [Tests] Port tensor control flow AD tests * Remove accidental whitespace * Update `testRecursion()` to use XCTAssertEqual assertions * Update comparisons and assertions
1 parent 25c7cfe commit 786e774

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,116 @@ final class TensorAutoDiffTests: XCTestCase {
4141
XCTAssertEqual(gradient(at: Tensor([0.1, 0.2, 0.3]), in: square), [0.2, 0.4, 0.6])
4242
}
4343

44+
func testConditionals() {
45+
func condNestedTupleVar(_ x: Tensor<Float>) -> Tensor<Float> {
46+
// Convoluted function returning `x + x`.
47+
var y: (Tensor<Float>, Tensor<Float>) = (x + x, x - x)
48+
var z: ((Tensor<Float>, Tensor<Float>), Tensor<Float>) = (y, x)
49+
if (x .> 0).all() {
50+
let w = (x, x)
51+
y.0 = w.1
52+
y.1 = w.0
53+
z.0.0 = z.0.0 - y.0
54+
z.0.1 = z.0.1 + y.0
55+
} else {
56+
z = ((y.0 - x, y.1 + x), x)
57+
}
58+
return y.0 + y.1 - z.0.0 + z.0.1
59+
}
60+
XCTAssertTrue((value: Tensor(8), gradient: Tensor(2)) == valueWithGradient(at: Tensor(4), in: condNestedTupleVar))
61+
XCTAssertTrue((value: Tensor(-20), gradient: Tensor(2)) == valueWithGradient(at: Tensor(-10), in: condNestedTupleVar))
62+
XCTAssertTrue((value: Tensor(-2674), gradient: Tensor(2)) == valueWithGradient(at: Tensor(-1337), in: condNestedTupleVar))
63+
64+
func guard2Var(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
65+
var z = y
66+
guard (x .> 0).all() else {
67+
if (y .> 0).all() {
68+
z = z * x
69+
} else if x == Tensor(-1337) {
70+
z = x
71+
z = z * z
72+
} else {
73+
z = Tensor(0)
74+
}
75+
return z
76+
}
77+
return z * y
78+
}
79+
XCTAssertTrue((Tensor(0), Tensor(10)) == gradient(at: Tensor(4), Tensor(5), in: guard2Var))
80+
XCTAssertTrue((Tensor(5), Tensor(-1337)) == gradient(at: Tensor(-1337), Tensor(5), in: guard2Var))
81+
XCTAssertTrue((Tensor(-2674), Tensor(0)) == gradient(at: Tensor(-1337), Tensor(-5), in: guard2Var))
82+
XCTAssertTrue((Tensor(2), Tensor(-3)) == gradient(at: Tensor(-3), Tensor(2), in: guard2Var))
83+
}
84+
85+
func testNestedConditionals() {
86+
// Test tensor-tensor ops.
87+
func condNested1(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
88+
if (x .> 0).all() {
89+
if (y .> 10).all() {
90+
let z = x * y
91+
if (z .> 100).all() {
92+
return x + z
93+
} else if y == Tensor(20) {
94+
return z + z
95+
}
96+
} else {
97+
return x + y
98+
}
99+
}
100+
return -y
101+
}
102+
XCTAssertTrue((Tensor(40), Tensor(8)) == gradient(at: Tensor(4), Tensor(20), in: condNested1))
103+
XCTAssertTrue((Tensor(0), Tensor(-1)) == gradient(at: Tensor(4), Tensor(21), in: condNested1))
104+
XCTAssertTrue((Tensor(1), Tensor(1)) == gradient(at: Tensor(4), Tensor(5), in: condNested1))
105+
XCTAssertTrue((Tensor(0), Tensor(-1)) == gradient(at: Tensor(-3), Tensor(-2), in: condNested1))
106+
107+
// Test tensor-scalar ops.
108+
func condNested2(_ x: Tensor<Float>, _ y: Float) -> Tensor<Float> {
109+
if (x .> 0).all() {
110+
if y > 10 {
111+
let z = x * y
112+
if (z .> 100).all() {
113+
return x + z
114+
} else if y == 20 {
115+
return z + z
116+
}
117+
} else {
118+
return x + y
119+
}
120+
}
121+
return Tensor(-y)
122+
}
123+
XCTAssertTrue((Tensor(40), 8) == gradient(at: Tensor(4), 20, in: condNested2))
124+
XCTAssertTrue((Tensor(0), -1) == gradient(at: Tensor(4), 21, in: condNested2))
125+
XCTAssertTrue((Tensor(1), 1) == gradient(at: Tensor(4), 5, in: condNested2))
126+
XCTAssertTrue((Tensor(0), -1) == gradient(at: Tensor(-3), -2, in: condNested2))
127+
}
128+
129+
func testRecursion() {
130+
func factorial(_ x: Tensor<Float>) -> Tensor<Float> {
131+
if x == Tensor(1) {
132+
return Tensor(1)
133+
}
134+
return x * factorial(x - 1)
135+
}
136+
XCTAssertEqual(gradient(at: Tensor(1), in: factorial), Tensor(0))
137+
XCTAssertEqual(gradient(at: Tensor(2), in: factorial), Tensor(1))
138+
XCTAssertEqual(gradient(at: Tensor(3), in: factorial), Tensor(5))
139+
XCTAssertEqual(gradient(at: Tensor(4), in: factorial), Tensor(26))
140+
XCTAssertEqual(gradient(at: Tensor(5), in: factorial), Tensor(154))
141+
142+
func product(_ x: Tensor<Float>, count: Int) -> Tensor<Float> {
143+
precondition(count > 0)
144+
if count == 1 {
145+
return x
146+
}
147+
return x * product(x, count: count - 1)
148+
}
149+
XCTAssertEqual(gradient(at: Tensor(-10), in: { x in product(x, count: 2) }), Tensor(-20))
150+
XCTAssertEqual(gradient(at: Tensor(10), in: { x in product(x, count: 3) }), Tensor(300))
151+
XCTAssertEqual(gradient(at: Tensor(100), in: { x in product(x, count: 1) }), Tensor(1))
152+
}
153+
44154
func testScalarGenericGrad() {
45155
// Tests TF-287.
46156
func negate<T : TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
@@ -567,6 +677,9 @@ final class TensorAutoDiffTests: XCTestCase {
567677
static var allTests = [
568678
("testSimpleGrad", testSimpleGrad),
569679
("testGenericGrad", testGenericGrad),
680+
("testConditionals", testConditionals),
681+
("testNestedConditionals", testNestedConditionals),
682+
("testRecursion", testRecursion),
570683
("testScalarGenericGrad", testScalarGenericGrad),
571684
("testScalarized", testScalarized),
572685
("testScalars", testScalars),

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public func allTests() -> [XCTestCaseEntry] {
3939
testCase(SequentialTests.allTests),
4040
testCase(TensorAutoDiffTests.allTests),
4141
testCase(TensorGroupTests.allTests),
42+
testCase(TensorAutoDiffTests.allTests),
4243
testCase(TensorTests.allTests),
4344
testCase(TrivialModelTests.allTests),
4445
]

0 commit comments

Comments
 (0)