Skip to content

Commit 0225f2f

Browse files
committed
Remove false positive error message in the executor_runner
1 parent a347665 commit 0225f2f

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,14 @@ bool check_tensor_dtype(
2828
case SupportedTensorDtypes::INTB:
2929
return executorch::runtime::tensor_is_integral_type(t, true);
3030
case SupportedTensorDtypes::BOOL_OR_BYTE:
31-
return (
32-
executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
33-
executorch::runtime::tensor_is_type(t, ScalarType::Byte));
31+
return (executorch::runtime::tensor_is_type(
32+
t, {ScalarType::Bool, ScalarType::Byte}));
3433
case SupportedTensorDtypes::SAME_AS_COMPUTE:
3534
return executorch::runtime::tensor_is_type(t, compute_type);
3635
case SupportedTensorDtypes::SAME_AS_COMMON: {
3736
if (compute_type == ScalarType::Float) {
38-
return (
39-
executorch::runtime::tensor_is_type(t, ScalarType::Float) ||
40-
executorch::runtime::tensor_is_type(t, ScalarType::Half) ||
41-
executorch::runtime::tensor_is_type(t, ScalarType::BFloat16));
37+
return (executorch::runtime::tensor_is_type(
38+
t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16}));
4239
} else {
4340
return executorch::runtime::tensor_is_type(t, compute_type);
4441
}

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <cmath>
1515
#include <cstddef> // size_t
1616
#include <limits>
17+
#include <sstream>
18+
#include <vector>
1719

1820
#include <executorch/runtime/core/array_ref.h>
1921
#include <executorch/runtime/core/error.h>
@@ -484,6 +486,30 @@ inline bool tensor_is_type(
484486
return true;
485487
}
486488

489+
inline bool tensor_is_type(
490+
executorch::aten::Tensor t,
491+
const std::vector<executorch::aten::ScalarType>& dtypes) {
492+
if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) !=
493+
dtypes.end()) {
494+
return true;
495+
}
496+
497+
std::stringstream dtype_ss;
498+
for (size_t i = 0; i < dtypes.size(); i++) {
499+
if (i != 0) {
500+
dtype_ss << ", ";
501+
}
502+
dtype_ss << torch::executor::toString(dtypes[i]);
503+
}
504+
505+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
506+
false,
507+
"Expected to find one of %s types, but tensor has type %s",
508+
dtype_ss.str().c_str(),
509+
torch::executor::toString(t.scalar_type()));
510+
return false;
511+
}
512+
487513
inline bool tensor_is_integral_type(
488514
executorch::aten::Tensor t,
489515
bool includeBool = false) {

0 commit comments

Comments
 (0)