Skip to content

Commit f13aeff

Browse files
authored
Utility helper to deduce scalar type from NSNumber. (#9552)
Summary: . Differential Revision: D71752750
1 parent cabc4e9 commit f13aeff

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,40 @@
77
*/
88

99
#import <Foundation/Foundation.h>
10+
11+
#ifdef __cplusplus
12+
13+
#import <executorch/runtime/core/exec_aten/exec_aten.h>
14+
15+
namespace executorch::extension::utils {
16+
using namespace aten;
17+
18+
/**
19+
* Deduces the scalar type for a given NSNumber based on its type encoding.
20+
*
21+
* @param number The NSNumber instance whose scalar type is to be deduced.
22+
* @return The corresponding ScalarType.
23+
*/
24+
static inline ScalarType deduceScalarType(NSNumber *number) {
25+
auto type = [number objCType][0];
26+
type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type;
27+
if (type == 'c') {
28+
return ScalarType::Byte;
29+
} else if (type == 's') {
30+
return ScalarType::Short;
31+
} else if (type == 'i') {
32+
return ScalarType::Int;
33+
} else if (type == 'q' || type == 'l') {
34+
return ScalarType::Long;
35+
} else if (type == 'f') {
36+
return ScalarType::Float;
37+
} else if (type == 'd') {
38+
return ScalarType::Double;
39+
}
40+
ET_CHECK_MSG(false, "Unsupported type: %c", type);
41+
return ScalarType::Undefined;
42+
}
43+
44+
} // namespace executorch::extension::utils
45+
46+
#endif // __cplusplus

0 commit comments

Comments
 (0)