@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
72
72
return result;
73
73
}
74
74
75
+ template <typename CTYPE_COMPUTE, const char * op_name>
76
+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool (const Tensor& t) {
77
+ ET_CHECK_MSG (
78
+ t.scalar_type () == ScalarType::Bool,
79
+ " Unhandled dtype %s for %s" ,
80
+ ::executorch::runtime::toString (t.scalar_type()),
81
+ op_name);
82
+ return internal::load_and_convert<CTYPE_COMPUTE, bool >;
83
+ }
84
+
75
85
template <typename CTYPE_COMPUTE, const char * op_name>
76
86
load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte (
77
87
const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
165
175
return result;
166
176
}
167
177
178
+ template <typename CTYPE_COMPUTE, const char * op_name>
179
+ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool (
180
+ const Tensor& t) {
181
+ ET_CHECK_MSG (
182
+ t.scalar_type () == ScalarType::Bool,
183
+ " Unhandled dtype %s for %s" ,
184
+ ::executorch::runtime::toString (t.scalar_type()),
185
+ op_name);
186
+ return internal::convert_and_store<bool , CTYPE_COMPUTE>;
187
+ }
188
+
168
189
template <typename CTYPE_COMPUTE, const char * op_name>
169
190
store_compute_to_tensor_fn<CTYPE_COMPUTE>
170
191
get_store_compute_to_tensor_fn_bool_or_byte (const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
219
240
REALHBF16,
220
241
FLOATHBF16,
221
242
INTB,
243
+ BOOL,
222
244
BOOL_OR_BYTE,
223
245
// DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
224
246
SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
240
262
return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
241
263
case SupportedTensorDtypes::INTB:
242
264
return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
265
+ case SupportedTensorDtypes::BOOL:
266
+ return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
243
267
case SupportedTensorDtypes::BOOL_OR_BYTE:
244
268
return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
245
269
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
271
295
t);
272
296
case SupportedTensorDtypes::INTB:
273
297
return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
298
+ case SupportedTensorDtypes::BOOL:
299
+ return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
274
300
case SupportedTensorDtypes::BOOL_OR_BYTE:
275
301
return get_store_compute_to_tensor_fn_bool_or_byte<
276
302
CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
318
344
const ScalarType compute_type);
319
345
320
346
// / Return the one output type we are willing to emit specialized code
321
- // / to handle, given a compute type of CTYPE_COMMON and supported
347
+ // / to handle, given a compute type of CTYPE_COMPUTE and supported
322
348
// / output types of out_dtypes.
323
349
template <typename CTYPE_COMPUTE>
324
350
inline constexpr ScalarType specialized_output_scalar_type (
325
351
SupportedTensorDtypes out_dtypes) {
326
352
switch (out_dtypes) {
353
+ case SupportedTensorDtypes::BOOL:
354
+ return ScalarType::Bool;
327
355
case SupportedTensorDtypes::BOOL_OR_BYTE:
328
356
return ScalarType::Bool;
329
357
case SupportedTensorDtypes::REALHBBF16:
0 commit comments