Skip to content

Commit 9123e91

Browse files
authored
Add SupportedTensorDtypes::BOOL (#9584)
1 parent 2dedc9e commit 9123e91

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
2727
return executorch::runtime::tensor_is_floating_type(t);
2828
case SupportedTensorDtypes::INTB:
2929
return executorch::runtime::tensor_is_integral_type(t, true);
30+
case SupportedTensorDtypes::BOOL:
31+
return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
3032
case SupportedTensorDtypes::BOOL_OR_BYTE:
3133
return (executorch::runtime::tensor_is_type(
3234
t, ScalarType::Bool, ScalarType::Byte));

kernels/portable/cpu/util/dtype_util.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
7272
return result;
7373
}
7474

75+
template <typename CTYPE_COMPUTE, const char* op_name>
76+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
77+
ET_CHECK_MSG(
78+
t.scalar_type() == ScalarType::Bool,
79+
"Unhandled dtype %s for %s",
80+
::executorch::runtime::toString(t.scalar_type()),
81+
op_name);
82+
return internal::load_and_convert<CTYPE_COMPUTE, bool>;
83+
}
84+
7585
template <typename CTYPE_COMPUTE, const char* op_name>
7686
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
7787
const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
165175
return result;
166176
}
167177

178+
template <typename CTYPE_COMPUTE, const char* op_name>
179+
store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
180+
const Tensor& t) {
181+
ET_CHECK_MSG(
182+
t.scalar_type() == ScalarType::Bool,
183+
"Unhandled dtype %s for %s",
184+
::executorch::runtime::toString(t.scalar_type()),
185+
op_name);
186+
return internal::convert_and_store<bool, CTYPE_COMPUTE>;
187+
}
188+
168189
template <typename CTYPE_COMPUTE, const char* op_name>
169190
store_compute_to_tensor_fn<CTYPE_COMPUTE>
170191
get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
219240
REALHBF16,
220241
FLOATHBF16,
221242
INTB,
243+
BOOL,
222244
BOOL_OR_BYTE,
223245
// DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
224246
SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
240262
return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
241263
case SupportedTensorDtypes::INTB:
242264
return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
265+
case SupportedTensorDtypes::BOOL:
266+
return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
243267
case SupportedTensorDtypes::BOOL_OR_BYTE:
244268
return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
245269
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
271295
t);
272296
case SupportedTensorDtypes::INTB:
273297
return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
298+
case SupportedTensorDtypes::BOOL:
299+
return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
274300
case SupportedTensorDtypes::BOOL_OR_BYTE:
275301
return get_store_compute_to_tensor_fn_bool_or_byte<
276302
CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
318344
const ScalarType compute_type);
319345

320346
/// Return the one output type we are willing to emit specialized code
321-
/// to handle, given a compute type of CTYPE_COMMON and supported
347+
/// to handle, given a compute type of CTYPE_COMPUTE and supported
322348
/// output types of out_dtypes.
323349
template <typename CTYPE_COMPUTE>
324350
inline constexpr ScalarType specialized_output_scalar_type(
325351
SupportedTensorDtypes out_dtypes) {
326352
switch (out_dtypes) {
353+
case SupportedTensorDtypes::BOOL:
354+
return ScalarType::Bool;
327355
case SupportedTensorDtypes::BOOL_OR_BYTE:
328356
return ScalarType::Bool;
329357
case SupportedTensorDtypes::REALHBBF16:

0 commit comments

Comments
 (0)