Skip to content

Commit 129aac6

Browse files
authored
Update TensorTest.swift
1 parent 70a695a commit 129aac6

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,54 @@ class TensorTest: XCTestCase {
148148
}
149149
}
150150

151+
func testWithUnsafeBytes() throws {
152+
var data: [Float] = [1, 2, 3, 4, 5, 6]
153+
let tensor = data.withUnsafeMutableBytes {
154+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
155+
}
156+
let array: [Float] = try tensor.withUnsafeBytes { Array($0) }
157+
XCTAssertEqual(array, data)
158+
}
159+
160+
func testWithUnsafeMutableBytes() throws {
161+
var data = [1, 2, 3, 4]
162+
let tensor = data.withUnsafeMutableBytes {
163+
Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .long)
164+
}
165+
try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Int>) in
166+
for i in buffer.indices {
167+
buffer[i] *= 2
168+
}
169+
}
170+
try tensor.withUnsafeBytes { buffer in
171+
XCTAssertEqual(Array(buffer), [2, 4, 6, 8])
172+
}
173+
}
174+
175+
func testWithUnsafeBytesFloat16() throws {
176+
var data: [Float16] = [1, 2, 3, 4, 5, 6]
177+
let tensor = data.withUnsafeMutableBytes {
178+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half)
179+
}
180+
let array: [Float16] = try tensor.withUnsafeBytes { Array($0) }
181+
XCTAssertEqual(array, data)
182+
}
183+
184+
func testWithUnsafeMutableBytesFloat16() throws {
185+
var data: [Float16] = [1, 2, 3, 4]
186+
let tensor = data.withUnsafeMutableBytes { buf in
187+
Tensor(bytes: buf.baseAddress!, shape: [4], dataType: .half)
188+
}
189+
try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Float16>) in
190+
for i in buffer.indices {
191+
buffer[i] *= 2
192+
}
193+
}
194+
try tensor.withUnsafeBytes { buffer in
195+
XCTAssertEqual(Array(buffer), data.map { $0 * 2 })
196+
}
197+
}
198+
151199
func testInitWithTensor() {
152200
var data: [Int] = [10, 20, 30, 40]
153201
let tensor1 = data.withUnsafeMutableBytes {
@@ -618,7 +666,7 @@ class TensorTest: XCTestCase {
618666
}
619667
}
620668
}
621-
669+
622670
func testZeros() {
623671
let tensor = Tensor.zeros(shape: [2, 3], dataType: .double)
624672
XCTAssertEqual(tensor.shape, [2, 3])

0 commit comments

Comments
 (0)