File tree Expand file tree Collapse file tree 2 files changed +30
-7
lines changed
kernels/portable/cpu/util
runtime/core/exec_aten/util Expand file tree Collapse file tree 2 files changed +30
-7
lines changed Original file line number Diff line number Diff line change @@ -28,17 +28,14 @@ bool check_tensor_dtype(
28
28
case SupportedTensorDtypes::INTB:
29
29
return executorch::runtime::tensor_is_integral_type (t, true );
30
30
case SupportedTensorDtypes::BOOL_OR_BYTE:
31
- return (
32
- executorch::runtime::tensor_is_type (t, ScalarType::Bool) ||
33
- executorch::runtime::tensor_is_type (t, ScalarType::Byte));
31
+ return (executorch::runtime::tensor_is_type (
32
+ t, {ScalarType::Bool, ScalarType::Byte}));
34
33
case SupportedTensorDtypes::SAME_AS_COMPUTE:
35
34
return executorch::runtime::tensor_is_type (t, compute_type);
36
35
case SupportedTensorDtypes::SAME_AS_COMMON: {
37
36
if (compute_type == ScalarType::Float) {
38
- return (
39
- executorch::runtime::tensor_is_type (t, ScalarType::Float) ||
40
- executorch::runtime::tensor_is_type (t, ScalarType::Half) ||
41
- executorch::runtime::tensor_is_type (t, ScalarType::BFloat16));
37
+ return (executorch::runtime::tensor_is_type (
38
+ t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16}));
42
39
} else {
43
40
return executorch::runtime::tensor_is_type (t, compute_type);
44
41
}
Original file line number Diff line number Diff line change 14
14
#include < cmath>
15
15
#include < cstddef> // size_t
16
16
#include < limits>
17
+ #include < sstream>
18
+ #include < vector>
17
19
18
20
#include < executorch/runtime/core/array_ref.h>
19
21
#include < executorch/runtime/core/error.h>
@@ -484,6 +486,30 @@ inline bool tensor_is_type(
484
486
return true ;
485
487
}
486
488
489
+ inline bool tensor_is_type (
490
+ executorch::aten::Tensor t,
491
+ const std::vector<executorch::aten::ScalarType>& dtypes) {
492
+ if (std::find (dtypes.begin (), dtypes.end (), t.scalar_type ()) !=
493
+ dtypes.end ()) {
494
+ return true ;
495
+ }
496
+
497
+ std::stringstream dtype_ss;
498
+ for (size_t i = 0 ; i < dtypes.size (); i++) {
499
+ if (i != 0 ) {
500
+ dtype_ss << " , " ;
501
+ }
502
+ dtype_ss << torch::executor::toString (dtypes[i]);
503
+ }
504
+
505
+ ET_LOG_MSG_AND_RETURN_IF_FALSE (
506
+ false ,
507
+ " Expected to find one of %s types, but tensor has type %s" ,
508
+ dtype_ss.str ().c_str (),
509
+ torch::executor::toString (t.scalar_type ()));
510
+ return false ;
511
+ }
512
+
487
513
inline bool tensor_is_integral_type (
488
514
executorch::aten::Tensor t,
489
515
bool includeBool = false ) {
You can’t perform that action at this time.
0 commit comments