Skip to content

Commit 2e2cf98

Browse files
authored
Utility helpers to convert between std::vector and NSArray.
Differential Revision: D71752746 Pull Request resolved: #9597
1 parent 9e8503c commit 2e2cf98

File tree

3 files changed

+82
-21
lines changed

3 files changed

+82
-21
lines changed

extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,7 @@ using namespace runtime;
2323
* @param number The NSNumber instance whose scalar type is to be deduced.
2424
* @return The corresponding ScalarType.
2525
*/
26-
static inline ScalarType deduceType(NSNumber *number) {
27-
auto type = [number objCType][0];
28-
type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type;
29-
if (type == 'c') {
30-
return ScalarType::Byte;
31-
} else if (type == 's') {
32-
return ScalarType::Short;
33-
} else if (type == 'i') {
34-
return ScalarType::Int;
35-
} else if (type == 'q' || type == 'l') {
36-
return ScalarType::Long;
37-
} else if (type == 'f') {
38-
return ScalarType::Float;
39-
} else if (type == 'd') {
40-
return ScalarType::Double;
41-
}
42-
ET_CHECK_MSG(false, "Unsupported type: %c", type);
43-
return ScalarType::Undefined;
44-
}
26+
ScalarType deduceType(NSNumber *number);
4527

4628
/**
4729
* Converts the value held in the NSNumber to the specified C++ type T.
@@ -51,8 +33,8 @@ static inline ScalarType deduceType(NSNumber *number) {
5133
* @return The value converted to type T.
5234
*/
5335
template <typename T>
54-
static inline T extractValue(NSNumber *number) {
55-
ET_CHECK_MSG(!(isFloatingType(deduceScalarType(number)) &&
36+
T extractValue(NSNumber *number) {
37+
ET_CHECK_MSG(!(isFloatingType(deduceType(number)) &&
5638
isIntegralType(CppTypeToScalarType<T>::value, true)),
5739
"Cannot convert floating point to integral type");
5840
T value;
@@ -93,6 +75,49 @@ static inline T extractValue(NSNumber *number) {
9375
return value;
9476
}
9577

78+
/**
79+
* Converts an NSArray of NSNumber objects to a std::vector of type T.
80+
*
81+
* @tparam T The target C++ numeric type.
82+
* @param array The NSArray containing NSNumber objects.
83+
* @return A std::vector with the values extracted as type T.
84+
*/
85+
template <typename T>
86+
std::vector<T> toVector(NSArray<NSNumber *> *array) {
87+
std::vector<T> vector;
88+
vector.reserve(array.count);
89+
for (NSNumber *number in array) {
90+
vector.push_back(extractValue<T>(number));
91+
}
92+
return vector;
93+
}
94+
95+
// Trait for types that can be wrapped into an NSNumber.
96+
template <typename T>
97+
constexpr bool isNSNumberWrapable =
98+
std::is_arithmetic_v<T> ||
99+
std::is_same_v<T, BOOL> ||
100+
std::is_same_v<T, BFloat16> ||
101+
std::is_same_v<T, Half>;
102+
103+
/**
104+
* Converts a generic container of numeric values to an NSArray of NSNumber objects.
105+
*
106+
* @tparam Container The container type holding numeric values.
107+
* @param container The container whose items are to be converted.
108+
* @return An NSArray populated with NSNumber objects representing the container's items.
109+
*/
110+
template <typename Container>
111+
NSArray<NSNumber *> *toNSArray(const Container &container) {
112+
static_assert(isNSNumberWrapable<typename Container::value_type>, "Invalid container value type");
113+
const NSUInteger count = std::distance(std::begin(container), std::end(container));
114+
NSMutableArray<NSNumber *> *array = [NSMutableArray arrayWithCapacity:count];
115+
for (const auto &item : container) {
116+
[array addObject:@(item)];
117+
}
118+
return array;
119+
}
120+
96121
} // namespace executorch::extension::utils
97122

98123
#endif // __cplusplus
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#import "ExecuTorchUtils.h"
10+
11+
namespace executorch::extension::utils {
12+
using namespace aten;
13+
using namespace runtime;
14+
15+
ScalarType deduceType(NSNumber *number) {
16+
auto type = [number objCType][0];
17+
type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type;
18+
if (type == 'c') {
19+
return ScalarType::Byte;
20+
} else if (type == 's') {
21+
return ScalarType::Short;
22+
} else if (type == 'i') {
23+
return ScalarType::Int;
24+
} else if (type == 'q' || type == 'l') {
25+
return ScalarType::Long;
26+
} else if (type == 'f') {
27+
return ScalarType::Float;
28+
} else if (type == 'd') {
29+
return ScalarType::Double;
30+
}
31+
ET_CHECK_MSG(false, "Unsupported type: %c", type);
32+
return ScalarType::Undefined;
33+
}
34+
35+
} // namespace executorch::extension::utils

runtime/core/span.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace runtime {
3535
template <typename T>
3636
class Span final {
3737
public:
38+
using value_type = T;
3839
using iterator = T*;
3940
using size_type = size_t;
4041

0 commit comments

Comments
 (0)