Skip to content

Commit 2d0d7d7

Browse files
authored
Create a native Tensor swift extension (#11076)
1 parent 62873b3 commit 2d0d7d7

File tree

3 files changed

+140
-2
lines changed

3 files changed

+140
-2
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
@_exported import ExecuTorch
10+
11+
/// A protocol that types conform to in order to be used as tensor element types.
12+
/// Provides the mapping from the Swift type to the underlying `DataType`.
13+
@available(*, deprecated, message: "This API is experimental.")
14+
protocol Scalar {
15+
/// The `DataType` corresponding to this scalar type.
16+
static var dataType: DataType { get }
17+
}
18+
19+
@available(*, deprecated, message: "This API is experimental.")
20+
extension UInt8: Scalar { static var dataType: DataType { .byte } }
21+
@available(*, deprecated, message: "This API is experimental.")
22+
extension Int8: Scalar { static var dataType: DataType { .char } }
23+
@available(*, deprecated, message: "This API is experimental.")
24+
extension Int16: Scalar { static var dataType: DataType { .short } }
25+
@available(*, deprecated, message: "This API is experimental.")
26+
extension Int32: Scalar { static var dataType: DataType { .int } }
27+
@available(*, deprecated, message: "This API is experimental.")
28+
extension Int64: Scalar { static var dataType: DataType { .long } }
29+
@available(*, deprecated, message: "This API is experimental.")
30+
extension Int: Scalar { static var dataType: DataType { .long } }
31+
@available(macOS 11.0, *)
32+
@available(*, deprecated, message: "This API is experimental.")
33+
extension Float16: Scalar { static var dataType: DataType { .half } }
34+
@available(*, deprecated, message: "This API is experimental.")
35+
extension Float: Scalar { static var dataType: DataType { .float } }
36+
@available(*, deprecated, message: "This API is experimental.")
37+
extension Double: Scalar { static var dataType: DataType { .double } }
38+
@available(*, deprecated, message: "This API is experimental.")
39+
extension Bool: Scalar { static var dataType: DataType { .bool } }
40+
@available(*, deprecated, message: "This API is experimental.")
41+
extension UInt16: Scalar { static var dataType: DataType { .uInt16 } }
42+
@available(*, deprecated, message: "This API is experimental.")
43+
extension UInt32: Scalar { static var dataType: DataType { .uInt32 } }
44+
@available(*, deprecated, message: "This API is experimental.")
45+
extension UInt64: Scalar { static var dataType: DataType { .uInt64 } }
46+
@available(*, deprecated, message: "This API is experimental.")
47+
extension UInt: Scalar { static var dataType: DataType { .uInt64 } }
48+
49+
@available(*, deprecated, message: "This API is experimental.")
50+
extension Tensor {
51+
/// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements.
52+
///
53+
/// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.
54+
/// - Returns: The value returned by `body`.
55+
/// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`,
56+
/// or any error thrown by `body`.
57+
func withUnsafeBytes<T: Scalar, R>(_ body: (UnsafeBufferPointer<T>) throws -> R) throws -> R {
58+
guard dataType == T.dataType else { throw Error(code: .invalidArgument) }
59+
var result: Result<R, Error>?
60+
bytes { pointer, count, _ in
61+
result = Result { try body(
62+
UnsafeBufferPointer(
63+
start: pointer.assumingMemoryBound(to: T.self),
64+
count: count
65+
)
66+
) }
67+
}
68+
return try result!.get()
69+
}
70+
71+
/// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements.
72+
///
73+
/// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer<T>` bound to the tensor’s data.
74+
/// - Returns: The value returned by `body`.
75+
/// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`,
76+
/// or any error thrown by `body`.
77+
func withUnsafeMutableBytes<T: Scalar, R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) throws -> R {
78+
guard dataType == T.dataType else { throw Error(code: .invalidArgument) }
79+
var result: Result<R, Error>?
80+
mutableBytes { pointer, count, _ in
81+
result = Result { try body(
82+
UnsafeMutableBufferPointer(
83+
start: pointer.assumingMemoryBound(to: T.self),
84+
count: count
85+
)
86+
) }
87+
}
88+
return try result!.get()
89+
}
90+
}

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,54 @@ class TensorTest: XCTestCase {
148148
}
149149
}
150150

151+
func testWithUnsafeBytes() throws {
152+
var data: [Float] = [1, 2, 3, 4, 5, 6]
153+
let tensor = data.withUnsafeMutableBytes {
154+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
155+
}
156+
let array: [Float] = try tensor.withUnsafeBytes { Array($0) }
157+
XCTAssertEqual(array, data)
158+
}
159+
160+
func testWithUnsafeMutableBytes() throws {
161+
var data = [1, 2, 3, 4]
162+
let tensor = data.withUnsafeMutableBytes {
163+
Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .long)
164+
}
165+
try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Int>) in
166+
for i in buffer.indices {
167+
buffer[i] *= 2
168+
}
169+
}
170+
try tensor.withUnsafeBytes { buffer in
171+
XCTAssertEqual(Array(buffer), [2, 4, 6, 8])
172+
}
173+
}
174+
175+
func testWithUnsafeBytesFloat16() throws {
176+
var data: [Float16] = [1, 2, 3, 4, 5, 6]
177+
let tensor = data.withUnsafeMutableBytes {
178+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half)
179+
}
180+
let array: [Float16] = try tensor.withUnsafeBytes { Array($0) }
181+
XCTAssertEqual(array, data)
182+
}
183+
184+
func testWithUnsafeMutableBytesFloat16() throws {
185+
var data: [Float16] = [1, 2, 3, 4]
186+
let tensor = data.withUnsafeMutableBytes { buffer in
187+
Tensor(bytes: buffer.baseAddress!, shape: [4], dataType: .half)
188+
}
189+
try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Float16>) in
190+
for i in buffer.indices {
191+
buffer[i] *= 2
192+
}
193+
}
194+
try tensor.withUnsafeBytes { buffer in
195+
XCTAssertEqual(Array(buffer), data.map { $0 * 2 })
196+
}
197+
}
198+
151199
func testInitWithTensor() {
152200
var data: [Int] = [10, 20, 30, 40]
153201
let tensor1 = data.withUnsafeMutableBytes {
@@ -618,7 +666,7 @@ class TensorTest: XCTestCase {
618666
}
619667
}
620668
}
621-
669+
622670
func testZeros() {
623671
let tensor = Tensor.zeros(shape: [2, 3], dataType: .double)
624672
XCTAssertEqual(tensor.shape, [2, 3])

scripts/build_apple_frameworks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ libextension_data_loader.a,\
3333
libextension_flat_tensor.a,\
3434
libextension_module.a,\
3535
libextension_tensor.a,\
36-
:$HEADERS_PATH"
36+
:$HEADERS_PATH:ExecuTorch"
3737

3838
FRAMEWORK_BACKEND_COREML="backend_coreml:\
3939
libcoreml_util.a,\

0 commit comments

Comments
 (0)