@@ -148,6 +148,54 @@ class TensorTest: XCTestCase {
148
148
}
149
149
}
150
150
151
+ func testWithUnsafeBytes( ) throws {
152
+ var data : [ Float ] = [ 1 , 2 , 3 , 4 , 5 , 6 ]
153
+ let tensor = data. withUnsafeMutableBytes {
154
+ Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 2 , 3 ] , dataType: . float)
155
+ }
156
+ let array : [ Float ] = try tensor. withUnsafeBytes { Array ( $0) }
157
+ XCTAssertEqual ( array, data)
158
+ }
159
+
160
+ func testWithUnsafeMutableBytes( ) throws {
161
+ var data = [ 1 , 2 , 3 , 4 ]
162
+ let tensor = data. withUnsafeMutableBytes {
163
+ Tensor ( bytes: $0. baseAddress!, shape: [ 4 ] , dataType: . long)
164
+ }
165
+ try tensor. withUnsafeMutableBytes { ( buffer: UnsafeMutableBufferPointer < Int > ) in
166
+ for i in buffer. indices {
167
+ buffer [ i] *= 2
168
+ }
169
+ }
170
+ try tensor. withUnsafeBytes { buffer in
171
+ XCTAssertEqual ( Array ( buffer) , [ 2 , 4 , 6 , 8 ] )
172
+ }
173
+ }
174
+
175
+ func testWithUnsafeBytesFloat16( ) throws {
176
+ var data : [ Float16 ] = [ 1 , 2 , 3 , 4 , 5 , 6 ]
177
+ let tensor = data. withUnsafeMutableBytes {
178
+ Tensor ( bytesNoCopy: $0. baseAddress!, shape: [ 6 ] , dataType: . half)
179
+ }
180
+ let array : [ Float16 ] = try tensor. withUnsafeBytes { Array ( $0) }
181
+ XCTAssertEqual ( array, data)
182
+ }
183
+
184
+ func testWithUnsafeMutableBytesFloat16( ) throws {
185
+ var data : [ Float16 ] = [ 1 , 2 , 3 , 4 ]
186
+ let tensor = data. withUnsafeMutableBytes { buf in
187
+ Tensor ( bytes: buf. baseAddress!, shape: [ 4 ] , dataType: . half)
188
+ }
189
+ try tensor. withUnsafeMutableBytes { ( buffer: UnsafeMutableBufferPointer < Float16 > ) in
190
+ for i in buffer. indices {
191
+ buffer [ i] *= 2
192
+ }
193
+ }
194
+ try tensor. withUnsafeBytes { buffer in
195
+ XCTAssertEqual ( Array ( buffer) , data. map { $0 * 2 } )
196
+ }
197
+ }
198
+
151
199
func testInitWithTensor( ) {
152
200
var data : [ Int ] = [ 10 , 20 , 30 , 40 ]
153
201
let tensor1 = data. withUnsafeMutableBytes {
@@ -618,7 +666,7 @@ class TensorTest: XCTestCase {
618
666
}
619
667
}
620
668
}
621
-
669
+
622
670
func testZeros( ) {
623
671
let tensor = Tensor . zeros ( shape: [ 2 , 3 ] , dataType: . double)
624
672
XCTAssertEqual ( tensor. shape, [ 2 , 3 ] )
0 commit comments