Skip to content

Commit 3e6764e

Browse files
committed
simplify
1 parent 0225f2f commit 3e6764e

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ bool check_tensor_dtype(
2929
return executorch::runtime::tensor_is_integral_type(t, true);
3030
case SupportedTensorDtypes::BOOL_OR_BYTE:
3131
return (executorch::runtime::tensor_is_type(
32-
t, {ScalarType::Bool, ScalarType::Byte}));
32+
t, ScalarType::Bool, ScalarType::Byte));
3333
case SupportedTensorDtypes::SAME_AS_COMPUTE:
3434
return executorch::runtime::tensor_is_type(t, compute_type);
3535
case SupportedTensorDtypes::SAME_AS_COMMON: {
3636
if (compute_type == ScalarType::Float) {
3737
return (executorch::runtime::tensor_is_type(
38-
t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16}));
38+
t, ScalarType::Float, ScalarType::Half, ScalarType::BFloat16));
3939
} else {
4040
return executorch::runtime::tensor_is_type(t, compute_type);
4141
}

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
#include <cmath>
1515
#include <cstddef> // size_t
1616
#include <limits>
17-
#include <sstream>
18-
#include <vector>
1917

2018
#include <executorch/runtime/core/array_ref.h>
2119
#include <executorch/runtime/core/error.h>
@@ -488,26 +486,33 @@ inline bool tensor_is_type(
488486

489487
inline bool tensor_is_type(
490488
executorch::aten::Tensor t,
491-
const std::vector<executorch::aten::ScalarType>& dtypes) {
492-
if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) !=
493-
dtypes.end()) {
494-
return true;
495-
}
489+
executorch::aten::ScalarType dtype,
490+
executorch::aten::ScalarType dtype2) {
491+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
492+
t.scalar_type() == dtype || t.scalar_type() == dtype2,
493+
"Expected to find %s or %s type, but tensor has type %s",
494+
torch::executor::toString(dtype),
495+
torch::executor::toString(dtype2),
496+
torch::executor::toString(t.scalar_type()));
496497

497-
std::stringstream dtype_ss;
498-
for (size_t i = 0; i < dtypes.size(); i++) {
499-
if (i != 0) {
500-
dtype_ss << ", ";
501-
}
502-
dtype_ss << torch::executor::toString(dtypes[i]);
503-
}
498+
return true;
499+
}
504500

501+
inline bool tensor_is_type(
502+
executorch::aten::Tensor t,
503+
executorch::aten::ScalarType dtype,
504+
executorch::aten::ScalarType dtype2,
505+
executorch::aten::ScalarType dtype3) {
505506
ET_LOG_MSG_AND_RETURN_IF_FALSE(
506-
false,
507-
"Expected to find one of %s types, but tensor has type %s",
508-
dtype_ss.str().c_str(),
507+
t.scalar_type() == dtype || t.scalar_type() == dtype2 ||
508+
t.scalar_type() == dtype3,
509+
"Expected to find %s, %s, or %s type, but tensor has type %s",
510+
torch::executor::toString(dtype),
511+
torch::executor::toString(dtype2),
512+
torch::executor::toString(dtype3),
509513
torch::executor::toString(t.scalar_type()));
510-
return false;
514+
515+
return true;
511516
}
512517

513518
inline bool tensor_is_integral_type(

0 commit comments

Comments
 (0)