Skip to content

Commit ddd9f4e

Browse files
committed
Add SupportedTensorDtypes::BOOL
ghstack-source-id: d483b1c ghstack-comment-id: 2751961032 Pull Request resolved: #9584
1 parent fa9e0f9 commit ddd9f4e

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
2727
return executorch::runtime::tensor_is_floating_type(t);
2828
case SupportedTensorDtypes::INTB:
2929
return executorch::runtime::tensor_is_integral_type(t, true);
30+
case SupportedTensorDtypes::BOOL:
31+
return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
3032
case SupportedTensorDtypes::BOOL_OR_BYTE:
3133
return (executorch::runtime::tensor_is_type(
3234
t, ScalarType::Bool, ScalarType::Byte));

kernels/portable/cpu/util/dtype_util.h

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
7272
return result;
7373
}
7474

75+
template <typename CTYPE_COMPUTE, const char* op_name>
76+
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
77+
ET_CHECK_MSG(
78+
t.scalar_type() == ScalarType::Bool,
79+
"Unhandled dtype %s for %s",
80+
::executorch::runtime::toString(t.scalar_type()),
81+
op_name);
82+
return internal::load_and_convert<CTYPE_COMPUTE, bool>;
83+
}
84+
7585
template <typename CTYPE_COMPUTE, const char* op_name>
7686
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
7787
const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
165175
return result;
166176
}
167177

178+
template <typename CTYPE_COMPUTE, const char* op_name>
179+
store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
180+
const Tensor& t) {
181+
ET_CHECK_MSG(
182+
t.scalar_type() == ScalarType::Bool,
183+
"Unhandled dtype %s for %s",
184+
::executorch::runtime::toString(t.scalar_type()),
185+
op_name);
186+
return internal::convert_and_store<bool, CTYPE_COMPUTE>;
187+
}
188+
168189
template <typename CTYPE_COMPUTE, const char* op_name>
169190
store_compute_to_tensor_fn<CTYPE_COMPUTE>
170191
get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
219240
REALHBF16,
220241
FLOATHBF16,
221242
INTB,
243+
BOOL,
222244
BOOL_OR_BYTE,
223245
// DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
224246
SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
240262
return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
241263
case SupportedTensorDtypes::INTB:
242264
return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
265+
case SupportedTensorDtypes::BOOL:
266+
return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
243267
case SupportedTensorDtypes::BOOL_OR_BYTE:
244268
return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
245269
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -261,20 +285,18 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
261285
SupportedTensorDtypes dtypes) {
262286
switch (dtypes) {
263287
case SupportedTensorDtypes::REALHBBF16:
264-
return get_store_compute_to_tensor_fn_realhbbf16<CTYPE_COMPUTE, op_name>(
265-
t);
288+
return get_store_compute_to_tensor_fn_realhbbf16<CTYPE_COMPUTE, op_name>(t);
266289
case SupportedTensorDtypes::REALHBF16:
267-
return get_store_compute_to_tensor_fn_realhbf16<CTYPE_COMPUTE, op_name>(
268-
t);
290+
return get_store_compute_to_tensor_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
269291
case SupportedTensorDtypes::FLOATHBF16:
270-
return get_store_compute_to_tensor_fn_floathbf16<CTYPE_COMPUTE, op_name>(
271-
t);
292+
return get_store_compute_to_tensor_fn_floathbf16<CTYPE_COMPUTE, op_name>(t);
272293
case SupportedTensorDtypes::INTB:
273294
return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
295+
case SupportedTensorDtypes::BOOL:
296+
return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
274297
case SupportedTensorDtypes::BOOL_OR_BYTE:
275-
return get_store_compute_to_tensor_fn_bool_or_byte<
276-
CTYPE_COMPUTE,
277-
op_name>(t);
298+
return get_store_compute_to_tensor_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(
299+
t);
278300
case SupportedTensorDtypes::SAME_AS_COMPUTE:
279301
return get_store_compute_to_tensor_fn_same_as_compute<
280302
CTYPE_COMPUTE,
@@ -318,12 +340,14 @@ bool check_tensor_dtype(
318340
const ScalarType compute_type);
319341

320342
/// Return the one output type we are willing to emit specialized code
321-
/// to handle, given a compute type of CTYPE_COMMON and supported
343+
/// to handle, given a compute type of CTYPE_COMPUTE and supported
322344
/// output types of out_dtypes.
323345
template <typename CTYPE_COMPUTE>
324346
inline constexpr ScalarType specialized_output_scalar_type(
325347
SupportedTensorDtypes out_dtypes) {
326348
switch (out_dtypes) {
349+
case SupportedTensorDtypes::BOOL:
350+
return ScalarType::Bool;
327351
case SupportedTensorDtypes::BOOL_OR_BYTE:
328352
return ScalarType::Bool;
329353
case SupportedTensorDtypes::REALHBBF16:

0 commit comments

Comments
 (0)