@@ -17,34 +17,34 @@ import XCTest
17
17
import CTensorFlow
18
18
19
19
extension TensorDataType : Equatable {
20
- public static func == ( lhs: TensorDataType , rhs: TensorDataType ) -> Bool {
21
- return Int ( lhs. _cDataType. rawValue) == Int ( rhs. _cDataType. rawValue)
22
- }
20
+ public static func == ( lhs: TensorDataType , rhs: TensorDataType ) -> Bool {
21
+ return Int ( lhs. _cDataType. rawValue) == Int ( rhs. _cDataType. rawValue)
22
+ }
23
23
}
24
24
25
25
struct Empty : TensorGroup { }
26
26
27
27
struct Simple : TensorGroup , Equatable {
28
- var w , b : Tensor < Float >
28
+ var w , b : Tensor < Float >
29
29
}
30
30
31
31
struct Mixed : TensorGroup , Equatable {
32
- // Mutable.
33
- var float : Tensor < Float >
34
- // Immutable.
35
- let int : Tensor < Int32 >
32
+ // Mutable.
33
+ var float : Tensor < Float >
34
+ // Immutable.
35
+ let int : Tensor < Int32 >
36
36
}
37
37
38
38
struct Nested : TensorGroup , Equatable {
39
- // Immutable.
40
- let simple : Simple
41
- // Mutable.
42
- var mixed : Mixed
39
+ // Immutable.
40
+ let simple : Simple
41
+ // Mutable.
42
+ var mixed : Mixed
43
43
}
44
44
45
45
struct Generic < T: TensorGroup & Equatable , U: TensorGroup & Equatable > : TensorGroup , Equatable {
46
- var t : T
47
- var u : U
46
+ var t : T
47
+ var u : U
48
48
}
49
49
50
50
final class TensorGroupTests : XCTestCase {
@@ -61,54 +61,54 @@ final class TensorGroupTests: XCTestCase {
61
61
let w = Tensor < Float > ( 0.1 )
62
62
let b = Tensor < Float > ( 0.1 )
63
63
let simple = Simple ( w: w, b: b)
64
-
64
+
65
65
let status = TF_NewStatus ( )
66
66
let wHandle = TFE_TensorHandleCopySharingTensor (
67
67
w. handle. _cTensorHandle, status) !
68
68
let bHandle = TFE_TensorHandleCopySharingTensor (
69
69
b. handle. _cTensorHandle, status) !
70
70
TF_DeleteStatus ( status)
71
-
71
+
72
72
let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate(
73
73
capacity: 2 )
74
74
let _ = buffer. initialize ( from: [ wHandle, bHandle] )
75
75
let expectedSimple = Simple ( _owning: UnsafePointer ( buffer. baseAddress) )
76
-
76
+
77
77
XCTAssertEqual ( expectedSimple, simple)
78
78
}
79
-
79
+
80
80
func testMixedTypeList( ) {
81
81
let float = Float . tensorFlowDataType
82
82
let int = Int32 . tensorFlowDataType
83
83
XCTAssertEqual ( [ float, int] , Mixed . _typeList)
84
84
}
85
-
85
+
86
86
func testMixedInit( ) {
87
87
let float = Tensor < Float > ( 0.1 )
88
88
let int = Tensor < Int32 > ( 1 )
89
89
let mixed = Mixed ( float: float, int: int)
90
-
90
+
91
91
let status = TF_NewStatus ( )
92
92
let floatHandle = TFE_TensorHandleCopySharingTensor (
93
93
float. handle. _cTensorHandle, status) !
94
94
let intHandle = TFE_TensorHandleCopySharingTensor (
95
95
int. handle. _cTensorHandle, status) !
96
96
TF_DeleteStatus ( status)
97
-
97
+
98
98
let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate(
99
99
capacity: 2 )
100
100
let _ = buffer. initialize ( from: [ floatHandle, intHandle] )
101
101
let expectedMixed = Mixed ( _owning: UnsafePointer ( buffer. baseAddress) )
102
-
102
+
103
103
XCTAssertEqual ( expectedMixed, mixed)
104
104
}
105
-
105
+
106
106
func testNestedTypeList( ) {
107
107
let float = Float . tensorFlowDataType
108
108
let int = Int32 . tensorFlowDataType
109
109
XCTAssertEqual ( [ float, float, float, int] , Nested . _typeList)
110
110
}
111
-
111
+
112
112
func testNestedInit( ) {
113
113
let w = Tensor < Float > ( 0.1 )
114
114
let b = Tensor < Float > ( 0.1 )
@@ -117,7 +117,7 @@ final class TensorGroupTests: XCTestCase {
117
117
let int = Tensor < Int32 > ( 1 )
118
118
let mixed = Mixed ( float: float, int: int)
119
119
let nested = Nested ( simple: simple, mixed: mixed)
120
-
120
+
121
121
let status = TF_NewStatus ( )
122
122
let wHandle = TFE_TensorHandleCopySharingTensor (
123
123
w. handle. _cTensorHandle, status) !
@@ -128,24 +128,24 @@ final class TensorGroupTests: XCTestCase {
128
128
let intHandle = TFE_TensorHandleCopySharingTensor (
129
129
int. handle. _cTensorHandle, status) !
130
130
TF_DeleteStatus ( status)
131
-
131
+
132
132
let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate(
133
133
capacity: 4 )
134
134
let _ = buffer. initialize (
135
135
from: [ wHandle, bHandle, floatHandle, intHandle] )
136
136
let expectedNested = Nested (
137
137
_owning: UnsafePointer ( buffer. baseAddress) )
138
-
138
+
139
139
XCTAssertEqual ( expectedNested, nested)
140
140
}
141
-
141
+
142
142
func testGenericTypeList( ) {
143
143
let float = Float . tensorFlowDataType
144
144
let int = Int32 . tensorFlowDataType
145
145
XCTAssertEqual (
146
146
[ float, float, float, int] , Generic< Simple, Mixed> . _typeList)
147
147
}
148
-
148
+
149
149
func testGenericInit( ) {
150
150
let w = Tensor < Float > ( 0.1 )
151
151
let b = Tensor < Float > ( 0.1 )
@@ -154,7 +154,7 @@ final class TensorGroupTests: XCTestCase {
154
154
let int = Tensor < Int32 > ( 1 )
155
155
let mixed = Mixed ( float: float, int: int)
156
156
let generic = Generic ( t: simple, u: mixed)
157
-
157
+
158
158
let status = TF_NewStatus ( )
159
159
let wHandle = TFE_TensorHandleCopySharingTensor (
160
160
w. handle. _cTensorHandle, status) !
@@ -165,17 +165,17 @@ final class TensorGroupTests: XCTestCase {
165
165
let intHandle = TFE_TensorHandleCopySharingTensor (
166
166
int. handle. _cTensorHandle, status) !
167
167
TF_DeleteStatus ( status)
168
-
168
+
169
169
let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate(
170
170
capacity: 4 )
171
171
let _ = buffer. initialize (
172
172
from: [ wHandle, bHandle, floatHandle, intHandle] )
173
173
let expectedGeneric = Generic < Simple , Mixed > (
174
174
_owning: UnsafePointer ( buffer. baseAddress) )
175
-
175
+
176
176
XCTAssertEqual ( expectedGeneric, generic)
177
177
}
178
-
178
+
179
179
func testNestedGenericTypeList( ) {
180
180
struct NestedGeneric {
181
181
func function( ) {
@@ -191,10 +191,10 @@ final class TensorGroupTests: XCTestCase {
191
191
UltraNested< Simple, Mixed> . _typeList)
192
192
}
193
193
}
194
-
194
+
195
195
NestedGeneric ( ) . function ( )
196
196
}
197
-
197
+
198
198
func testNestedGenericInit( ) {
199
199
struct NestedGeneric {
200
200
func function( ) {
@@ -204,7 +204,7 @@ final class TensorGroupTests: XCTestCase {
204
204
var a : Generic < T , V >
205
205
var b : Generic < V , T >
206
206
}
207
-
207
+
208
208
let w = Tensor < Float > ( 0.1 )
209
209
let b = Tensor < Float > ( 0.1 )
210
210
let simple = Simple ( w: w, b: b)
@@ -214,28 +214,38 @@ final class TensorGroupTests: XCTestCase {
214
214
let genericSM = Generic < Simple , Mixed > ( t: simple, u: mixed)
215
215
let genericMS = Generic < Mixed , Simple > ( t: mixed, u: simple)
216
216
let generic = UltraNested ( a: genericSM, b: genericMS)
217
-
217
+
218
218
let status = TF_NewStatus ( )
219
- let wHandle1 = TFE_TensorHandleCopySharingTensor ( w. handle. _cTensorHandle, status) !
220
- let wHandle2 = TFE_TensorHandleCopySharingTensor ( w. handle. _cTensorHandle, status) !
221
- let bHandle1 = TFE_TensorHandleCopySharingTensor ( b. handle. _cTensorHandle, status) !
222
- let bHandle2 = TFE_TensorHandleCopySharingTensor ( b. handle. _cTensorHandle, status) !
223
- let floatHandle1 = TFE_TensorHandleCopySharingTensor ( float. handle. _cTensorHandle, status) !
224
- let floatHandle2 = TFE_TensorHandleCopySharingTensor ( float. handle. _cTensorHandle, status) !
225
- let intHandle1 = TFE_TensorHandleCopySharingTensor ( int. handle. _cTensorHandle, status) !
226
- let intHandle2 = TFE_TensorHandleCopySharingTensor ( int. handle. _cTensorHandle, status) !
219
+ let wHandle1 = TFE_TensorHandleCopySharingTensor (
220
+ w. handle. _cTensorHandle, status) !
221
+ let wHandle2 = TFE_TensorHandleCopySharingTensor (
222
+ w. handle. _cTensorHandle, status) !
223
+ let bHandle1 = TFE_TensorHandleCopySharingTensor (
224
+ b. handle. _cTensorHandle, status) !
225
+ let bHandle2 = TFE_TensorHandleCopySharingTensor (
226
+ b. handle. _cTensorHandle, status) !
227
+ let floatHandle1 = TFE_TensorHandleCopySharingTensor (
228
+ float. handle. _cTensorHandle, status) !
229
+ let floatHandle2 = TFE_TensorHandleCopySharingTensor (
230
+ float. handle. _cTensorHandle, status) !
231
+ let intHandle1 = TFE_TensorHandleCopySharingTensor (
232
+ int. handle. _cTensorHandle, status) !
233
+ let intHandle2 = TFE_TensorHandleCopySharingTensor (
234
+ int. handle. _cTensorHandle, status) !
227
235
TF_DeleteStatus ( status)
228
-
229
- let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate( capacity: 8 )
230
- let _ = buffer. initialize ( from: [ wHandle1, bHandle1, floatHandle1, intHandle1,
236
+
237
+ let buffer = UnsafeMutableBufferPointer< CTensorHandle> . allocate(
238
+ capacity: 8 )
239
+ let _ = buffer. initialize (
240
+ from: [ wHandle1, bHandle1, floatHandle1, intHandle1,
231
241
floatHandle2, intHandle2, wHandle2, bHandle2] )
232
242
let expectedGeneric = UltraNested < Simple , Mixed > (
233
243
_owning: UnsafePointer ( buffer. baseAddress) )
234
-
244
+
235
245
XCTAssertEqual ( expectedGeneric, generic)
236
246
}
237
247
}
238
-
248
+
239
249
NestedGeneric ( ) . function ( )
240
250
}
241
251
0 commit comments