Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 85019f4

Browse files
committed
Moving the tests for TensorGroup from swift repo to swift-apis.
1 parent 3e8fa4f commit 85019f4

File tree

2 files changed

+257
-0
lines changed

2 files changed

+257
-0
lines changed
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import TensorFlow
17+
import CTensorFlow
18+
19+
extension TensorDataType : Equatable {
20+
public static func == (lhs: TensorDataType, rhs: TensorDataType) -> Bool {
21+
return Int(lhs._cDataType.rawValue) == Int(rhs._cDataType.rawValue)
22+
}
23+
}
24+
25+
struct Empty : TensorGroup {}
26+
27+
struct Simple : TensorGroup, Equatable {
28+
var w, b: Tensor<Float>
29+
}
30+
31+
struct Mixed : TensorGroup, Equatable {
32+
// Mutable.
33+
var float: Tensor<Float>
34+
// Immutable.
35+
let int: Tensor<Int32>
36+
}
37+
38+
struct Nested : TensorGroup, Equatable {
39+
// Immutable.
40+
let simple: Simple
41+
// Mutable.
42+
var mixed: Mixed
43+
}
44+
45+
struct Generic<T: TensorGroup & Equatable, U: TensorGroup & Equatable> : TensorGroup, Equatable {
46+
var t: T
47+
var u: U
48+
}
49+
50+
final class TensorGroupTests: XCTestCase {
51+
func testEmptyList() {
52+
XCTAssertEqual([], Empty._typeList)
53+
}
54+
55+
func testSimpleTypeList() {
56+
let float = Float.tensorFlowDataType
57+
XCTAssertEqual([float, float], Simple._typeList)
58+
}
59+
60+
func testSimpleInit() {
61+
let w = Tensor<Float>(0.1)
62+
let b = Tensor<Float>(0.1)
63+
let simple = Simple(w: w, b: b)
64+
65+
let status = TF_NewStatus()
66+
let wHandle = TFE_TensorHandleCopySharingTensor(
67+
w.handle._cTensorHandle, status)!
68+
let bHandle = TFE_TensorHandleCopySharingTensor(
69+
b.handle._cTensorHandle, status)!
70+
TF_DeleteStatus(status)
71+
72+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
73+
capacity: 2)
74+
let _ = buffer.initialize(from: [wHandle, bHandle])
75+
let expectedSimple = Simple(_owning: UnsafePointer(buffer.baseAddress))
76+
77+
XCTAssertEqual(expectedSimple, simple)
78+
}
79+
80+
func testMixedTypeList() {
81+
let float = Float.tensorFlowDataType
82+
let int = Int32.tensorFlowDataType
83+
XCTAssertEqual([float, int], Mixed._typeList)
84+
}
85+
86+
func testMixedInit() {
87+
let float = Tensor<Float>(0.1)
88+
let int = Tensor<Int32>(1)
89+
let mixed = Mixed(float: float, int: int)
90+
91+
let status = TF_NewStatus()
92+
let floatHandle = TFE_TensorHandleCopySharingTensor(
93+
float.handle._cTensorHandle, status)!
94+
let intHandle = TFE_TensorHandleCopySharingTensor(
95+
int.handle._cTensorHandle, status)!
96+
TF_DeleteStatus(status)
97+
98+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
99+
capacity: 2)
100+
let _ = buffer.initialize(from: [floatHandle, intHandle])
101+
let expectedMixed = Mixed(_owning: UnsafePointer(buffer.baseAddress))
102+
103+
XCTAssertEqual(expectedMixed, mixed)
104+
}
105+
106+
func testNestedTypeList() {
107+
let float = Float.tensorFlowDataType
108+
let int = Int32.tensorFlowDataType
109+
XCTAssertEqual([float, float, float, int], Nested._typeList)
110+
}
111+
112+
func testNestedInit() {
113+
let w = Tensor<Float>(0.1)
114+
let b = Tensor<Float>(0.1)
115+
let simple = Simple(w: w, b: b)
116+
let float = Tensor<Float>(0.1)
117+
let int = Tensor<Int32>(1)
118+
let mixed = Mixed(float: float, int: int)
119+
let nested = Nested(simple: simple, mixed: mixed)
120+
121+
let status = TF_NewStatus()
122+
let wHandle = TFE_TensorHandleCopySharingTensor(
123+
w.handle._cTensorHandle, status)!
124+
let bHandle = TFE_TensorHandleCopySharingTensor(
125+
b.handle._cTensorHandle, status)!
126+
let floatHandle = TFE_TensorHandleCopySharingTensor(
127+
float.handle._cTensorHandle, status)!
128+
let intHandle = TFE_TensorHandleCopySharingTensor(
129+
int.handle._cTensorHandle, status)!
130+
TF_DeleteStatus(status)
131+
132+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
133+
capacity: 4)
134+
let _ = buffer.initialize(
135+
from: [wHandle, bHandle, floatHandle, intHandle])
136+
let expectedNested = Nested(
137+
_owning: UnsafePointer(buffer.baseAddress))
138+
139+
XCTAssertEqual(expectedNested, nested)
140+
}
141+
142+
func testGenericTypeList() {
143+
let float = Float.tensorFlowDataType
144+
let int = Int32.tensorFlowDataType
145+
XCTAssertEqual(
146+
[float, float, float, int], Generic<Simple, Mixed>._typeList)
147+
}
148+
149+
func testGenericInit() {
150+
let w = Tensor<Float>(0.1)
151+
let b = Tensor<Float>(0.1)
152+
let simple = Simple(w: w, b: b)
153+
let float = Tensor<Float>(0.1)
154+
let int = Tensor<Int32>(1)
155+
let mixed = Mixed(float: float, int: int)
156+
let generic = Generic(t: simple, u: mixed)
157+
158+
let status = TF_NewStatus()
159+
let wHandle = TFE_TensorHandleCopySharingTensor(
160+
w.handle._cTensorHandle, status)!
161+
let bHandle = TFE_TensorHandleCopySharingTensor(
162+
b.handle._cTensorHandle, status)!
163+
let floatHandle = TFE_TensorHandleCopySharingTensor(
164+
float.handle._cTensorHandle, status)!
165+
let intHandle = TFE_TensorHandleCopySharingTensor(
166+
int.handle._cTensorHandle, status)!
167+
TF_DeleteStatus(status)
168+
169+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(
170+
capacity: 4)
171+
let _ = buffer.initialize(
172+
from: [wHandle, bHandle, floatHandle, intHandle])
173+
let expectedGeneric = Generic<Simple, Mixed>(
174+
_owning: UnsafePointer(buffer.baseAddress))
175+
176+
XCTAssertEqual(expectedGeneric, generic)
177+
}
178+
179+
func testNestedGenericTypeList() {
180+
struct NestedGeneric {
181+
func function() {
182+
struct UltraNested<
183+
T: TensorGroup & Equatable, V: TensorGroup & Equatable>
184+
: TensorGroup, Equatable {
185+
var a: Generic<T, V>
186+
var b: Generic<V, T>
187+
}
188+
let float = Float.tensorFlowDataType
189+
let int = Int32.tensorFlowDataType
190+
XCTAssertEqual([float, float, float, int, float, int, float, float],
191+
UltraNested<Simple, Mixed>._typeList)
192+
}
193+
}
194+
195+
NestedGeneric().function()
196+
}
197+
198+
func testNestedGenericInit() {
199+
struct NestedGeneric {
200+
func function() {
201+
struct UltraNested<
202+
T: TensorGroup & Equatable, V: TensorGroup & Equatable>
203+
: TensorGroup, Equatable {
204+
var a: Generic<T, V>
205+
var b: Generic<V, T>
206+
}
207+
208+
let w = Tensor<Float>(0.1)
209+
let b = Tensor<Float>(0.1)
210+
let simple = Simple(w: w, b: b)
211+
let float = Tensor<Float>(0.1)
212+
let int = Tensor<Int32>(1)
213+
let mixed = Mixed(float: float, int: int)
214+
let genericSM = Generic<Simple, Mixed>(t: simple, u: mixed)
215+
let genericMS = Generic<Mixed, Simple>(t: mixed, u: simple)
216+
let generic = UltraNested(a: genericSM, b: genericMS)
217+
218+
let status = TF_NewStatus()
219+
let wHandle1 = TFE_TensorHandleCopySharingTensor(w.handle._cTensorHandle, status)!
220+
let wHandle2 = TFE_TensorHandleCopySharingTensor(w.handle._cTensorHandle, status)!
221+
let bHandle1 = TFE_TensorHandleCopySharingTensor(b.handle._cTensorHandle, status)!
222+
let bHandle2 = TFE_TensorHandleCopySharingTensor(b.handle._cTensorHandle, status)!
223+
let floatHandle1 = TFE_TensorHandleCopySharingTensor(float.handle._cTensorHandle, status)!
224+
let floatHandle2 = TFE_TensorHandleCopySharingTensor(float.handle._cTensorHandle, status)!
225+
let intHandle1 = TFE_TensorHandleCopySharingTensor(int.handle._cTensorHandle, status)!
226+
let intHandle2 = TFE_TensorHandleCopySharingTensor(int.handle._cTensorHandle, status)!
227+
TF_DeleteStatus(status)
228+
229+
let buffer = UnsafeMutableBufferPointer<CTensorHandle>.allocate(capacity: 8)
230+
let _ = buffer.initialize(from: [wHandle1, bHandle1, floatHandle1, intHandle1,
231+
floatHandle2, intHandle2, wHandle2, bHandle2])
232+
let expectedGeneric = UltraNested<Simple, Mixed>(
233+
_owning: UnsafePointer(buffer.baseAddress))
234+
235+
XCTAssertEqual(expectedGeneric, generic)
236+
}
237+
}
238+
239+
NestedGeneric().function()
240+
}
241+
242+
static var allTests = [
243+
("testEmptyList", testEmptyList),
244+
("testSimpleTypeList", testSimpleTypeList),
245+
("testSimpleInit", testSimpleInit),
246+
("testMixedTypelist", testMixedTypeList),
247+
("testMixedInit", testMixedInit),
248+
("testNestedTypeList", testNestedTypeList),
249+
("testNestedInit", testNestedInit),
250+
("testGenericTypeList", testGenericTypeList),
251+
("testGenericInit", testGenericInit),
252+
("testNestedGenericTypeList", testNestedGenericTypeList),
253+
("testNestedGenericInit", testNestedGenericInit)
254+
]
255+
256+
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public func allTests() -> [XCTestCaseEntry] {
2323
testCase(SequentialTests.allTests),
2424
testCase(LayerTests.allTests),
2525
testCase(TensorTests.allTests),
26+
testCase(TensorGroupTests.allTests),
2627
testCase(BasicOperatorTests.allTests),
2728
testCase(ComparisonOperatorTests.allTests),
2829
testCase(DatasetTests.allTests),

0 commit comments

Comments
 (0)