Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 78e5a1d

Browse files
authored
Moving the tests for TensorGroup from swift repo to swift-apis. (#158)
* Moving the tests for TensorGroup from swift repo to swift-apis. * Fix indentation and whitespace issues.
1 parent 3e8fa4f commit 78e5a1d

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import TensorFlow
17+
import CTensorFlow
18+
19+
extension TensorDataType : Equatable {
20+
public static func == (lhs: TensorDataType, rhs: TensorDataType) -> Bool {
21+
return Int(lhs._cDataType.rawValue) == Int(rhs._cDataType.rawValue)
22+
}
23+
}
24+
25+
struct Empty : TensorGroup {}
26+
27+
struct Simple : TensorGroup, Equatable {
28+
var w, b: Tensor<Float>
29+
}
30+
31+
struct Mixed : TensorGroup, Equatable {
32+
// Mutable.
33+
var float: Tensor<Float>
34+
// Immutable.
35+
let int: Tensor<Int32>
36+
}
37+
38+
struct Nested : TensorGroup, Equatable {
39+
// Immutable.
40+
let simple: Simple
41+
// Mutable.
42+
var mixed: Mixed
43+
}
44+
45+
struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorGroup, Equatable {
46+
var t: T
47+
var u: U
48+
}
49+
50+
final class TensorGroupTests: XCTestCase {
51+
func testEmptyList() {
52+
XCTAssertEqual([], Empty._typeList)
53+
}
54+
55+
func testSimpleTypeList() {
56+
let float = Float.tensorFlowDataType
57+
XCTAssertEqual([float, float], Simple._typeList)
58+
}
59+
60+
func testSimpleInit() {
61+
let w = Tensor<Float>(0.1)
62+
let b = Tensor<Float>(0.1)
63+
let simple = Simple(w: w, b: b)
64+
65+
let status = TF_NewStatus()
66+
let wHandle = TFE_TensorHandleCopySharingTensor(
67+
w.handle._cTensorHandle, status)!
68+
let bHandle = TFE_TensorHandleCopySharingTensor(
69+
b.handle._cTensorHandle, status)!
70+
TF_DeleteStatus(status)
71+
72+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
73+
capacity: 2)
74+
let _ = buffer.initialize(from: [wHandle, bHandle])
75+
let expectedSimple = Simple(_owning: UnsafePointer(buffer.baseAddress))
76+
77+
XCTAssertEqual(expectedSimple, simple)
78+
}
79+
80+
func testMixedTypeList() {
81+
let float = Float.tensorFlowDataType
82+
let int = Int32.tensorFlowDataType
83+
XCTAssertEqual([float, int], Mixed._typeList)
84+
}
85+
86+
func testMixedInit() {
87+
let float = Tensor<Float>(0.1)
88+
let int = Tensor<Int32>(1)
89+
let mixed = Mixed(float: float, int: int)
90+
91+
let status = TF_NewStatus()
92+
let floatHandle = TFE_TensorHandleCopySharingTensor(
93+
float.handle._cTensorHandle, status)!
94+
let intHandle = TFE_TensorHandleCopySharingTensor(
95+
int.handle._cTensorHandle, status)!
96+
TF_DeleteStatus(status)
97+
98+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
99+
capacity: 2)
100+
let _ = buffer.initialize(from: [floatHandle, intHandle])
101+
let expectedMixed = Mixed(_owning: UnsafePointer(buffer.baseAddress))
102+
103+
XCTAssertEqual(expectedMixed, mixed)
104+
}
105+
106+
func testNestedTypeList() {
107+
let float = Float.tensorFlowDataType
108+
let int = Int32.tensorFlowDataType
109+
XCTAssertEqual([float, float, float, int], Nested._typeList)
110+
}
111+
112+
func testNestedInit() {
113+
let w = Tensor<Float>(0.1)
114+
let b = Tensor<Float>(0.1)
115+
let simple = Simple(w: w, b: b)
116+
let float = Tensor<Float>(0.1)
117+
let int = Tensor<Int32>(1)
118+
let mixed = Mixed(float: float, int: int)
119+
let nested = Nested(simple: simple, mixed: mixed)
120+
121+
let status = TF_NewStatus()
122+
let wHandle = TFE_TensorHandleCopySharingTensor(
123+
w.handle._cTensorHandle, status)!
124+
let bHandle = TFE_TensorHandleCopySharingTensor(
125+
b.handle._cTensorHandle, status)!
126+
let floatHandle = TFE_TensorHandleCopySharingTensor(
127+
float.handle._cTensorHandle, status)!
128+
let intHandle = TFE_TensorHandleCopySharingTensor(
129+
int.handle._cTensorHandle, status)!
130+
TF_DeleteStatus(status)
131+
132+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
133+
capacity: 4)
134+
let _ = buffer.initialize(
135+
from: [wHandle, bHandle, floatHandle, intHandle])
136+
let expectedNested = Nested(
137+
_owning: UnsafePointer(buffer.baseAddress))
138+
139+
XCTAssertEqual(expectedNested, nested)
140+
}
141+
142+
func testGenericTypeList() {
143+
let float = Float.tensorFlowDataType
144+
let int = Int32.tensorFlowDataType
145+
XCTAssertEqual(
146+
[float, float, float, int], Generic<Simple, Mixed>._typeList)
147+
}
148+
149+
func testGenericInit() {
150+
let w = Tensor<Float>(0.1)
151+
let b = Tensor<Float>(0.1)
152+
let simple = Simple(w: w, b: b)
153+
let float = Tensor<Float>(0.1)
154+
let int = Tensor<Int32>(1)
155+
let mixed = Mixed(float: float, int: int)
156+
let generic = Generic(t: simple, u: mixed)
157+
158+
let status = TF_NewStatus()
159+
let wHandle = TFE_TensorHandleCopySharingTensor(
160+
w.handle._cTensorHandle, status)!
161+
let bHandle = TFE_TensorHandleCopySharingTensor(
162+
b.handle._cTensorHandle, status)!
163+
let floatHandle = TFE_TensorHandleCopySharingTensor(
164+
float.handle._cTensorHandle, status)!
165+
let intHandle = TFE_TensorHandleCopySharingTensor(
166+
int.handle._cTensorHandle, status)!
167+
TF_DeleteStatus(status)
168+
169+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
170+
capacity: 4)
171+
let _ = buffer.initialize(
172+
from: [wHandle, bHandle, floatHandle, intHandle])
173+
let expectedGeneric = Generic<Simple, Mixed>(
174+
_owning: UnsafePointer(buffer.baseAddress))
175+
176+
XCTAssertEqual(expectedGeneric, generic)
177+
}
178+
179+
func testNestedGenericTypeList() {
180+
struct NestedGeneric {
181+
func function() {
182+
struct UltraNested<
183+
T: TensorGroup & Equatable, V: TensorGroup & Equatable>
184+
: TensorGroup, Equatable {
185+
var a: Generic<T, V>
186+
var b: Generic<V, T>
187+
}
188+
let float = Float.tensorFlowDataType
189+
let int = Int32.tensorFlowDataType
190+
XCTAssertEqual([float, float, float, int, float, int, float, float],
191+
UltraNested<Simple, Mixed>._typeList)
192+
}
193+
}
194+
195+
NestedGeneric().function()
196+
}
197+
198+
func testNestedGenericInit() {
199+
struct NestedGeneric {
200+
func function() {
201+
struct UltraNested<
202+
T: TensorGroup & Equatable, V: TensorGroup & Equatable>
203+
: TensorGroup, Equatable {
204+
var a: Generic<T, V>
205+
var b: Generic<V, T>
206+
}
207+
208+
let w = Tensor<Float>(0.1)
209+
let b = Tensor<Float>(0.1)
210+
let simple = Simple(w: w, b: b)
211+
let float = Tensor<Float>(0.1)
212+
let int = Tensor<Int32>(1)
213+
let mixed = Mixed(float: float, int: int)
214+
let genericSM = Generic<Simple, Mixed>(t: simple, u: mixed)
215+
let genericMS = Generic<Mixed, Simple>(t: mixed, u: simple)
216+
let generic = UltraNested(a: genericSM, b: genericMS)
217+
218+
let status = TF_NewStatus()
219+
let wHandle1 = TFE_TensorHandleCopySharingTensor(
220+
w.handle._cTensorHandle, status)!
221+
let wHandle2 = TFE_TensorHandleCopySharingTensor(
222+
w.handle._cTensorHandle, status)!
223+
let bHandle1 = TFE_TensorHandleCopySharingTensor(
224+
b.handle._cTensorHandle, status)!
225+
let bHandle2 = TFE_TensorHandleCopySharingTensor(
226+
b.handle._cTensorHandle, status)!
227+
let floatHandle1 = TFE_TensorHandleCopySharingTensor(
228+
float.handle._cTensorHandle, status)!
229+
let floatHandle2 = TFE_TensorHandleCopySharingTensor(
230+
float.handle._cTensorHandle, status)!
231+
let intHandle1 = TFE_TensorHandleCopySharingTensor(
232+
int.handle._cTensorHandle, status)!
233+
let intHandle2 = TFE_TensorHandleCopySharingTensor(
234+
int.handle._cTensorHandle, status)!
235+
TF_DeleteStatus(status)
236+
237+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
238+
capacity: 8)
239+
let _ = buffer.initialize(
240+
from: [wHandle1, bHandle1, floatHandle1, intHandle1,
241+
floatHandle2, intHandle2, wHandle2, bHandle2])
242+
let expectedGeneric = UltraNested<Simple, Mixed>(
243+
_owning: UnsafePointer(buffer.baseAddress))
244+
245+
XCTAssertEqual(expectedGeneric, generic)
246+
}
247+
}
248+
249+
NestedGeneric().function()
250+
}
251+
252+
static var allTests = [
253+
("testEmptyList", testEmptyList),
254+
("testSimpleTypeList", testSimpleTypeList),
255+
("testSimpleInit", testSimpleInit),
256+
("testMixedTypelist", testMixedTypeList),
257+
("testMixedInit", testMixedInit),
258+
("testNestedTypeList", testNestedTypeList),
259+
("testNestedInit", testNestedInit),
260+
("testGenericTypeList", testGenericTypeList),
261+
("testGenericInit", testGenericInit),
262+
("testNestedGenericTypeList", testNestedGenericTypeList),
263+
("testNestedGenericInit", testNestedGenericInit)
264+
]
265+
266+
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public func allTests() -> [XCTestCaseEntry] {
2323
testCase(SequentialTests.allTests),
2424
testCase(LayerTests.allTests),
2525
testCase(TensorTests.allTests),
26+
testCase(TensorGroupTests.allTests),
2627
testCase(BasicOperatorTests.allTests),
2728
testCase(ComparisonOperatorTests.allTests),
2829
testCase(DatasetTests.allTests),

0 commit comments

Comments
 (0)