Skip to content

Commit 37265a1

Browse files
authored
Remove false positive error message in the executor_runner (#7170)
### Summary There are false positive error messages from the type check that attempts to verify several allowed types at once, but uses one function call per type leading to some of them not passing the check and emitting an error message. ``` > ./executor_runner --model-path ~/model.pte ... I 00:00:00.519929 executorch:executor_runner.cpp:91] Using method forward I 00:00:00.519931 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 10753549312. I 00:00:01.212589 executorch:executor_runner.cpp:161] Method loaded. I 00:00:01.213544 executorch:executor_runner.cpp:171] Inputs prepared. E 00:00:02.320414 executorch:tensor_util.h:481] Expected to find Float type, but tensor has type Half E 00:00:02.320521 executorch:tensor_util.h:487] Check failed (t.scalar_type() == dtype): Expected to find Float type, but tensor has type Half ... I 00:00:24.911098 executorch:executor_runner.cpp:180] Model executed successfully. ... ``` Note the repeated lines > Check failed (t.scalar_type() == dtype): Expected to find Float type, but tensor has type Half ### Test plan I can only provide the log message from my local run for that ``` ... I 00:00:00.745945 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 13026185216. I 00:00:02.139255 executorch:executor_runner.cpp:161] Method loaded. I 00:00:02.139295 executorch:executor_runner.cpp:171] Inputs prepared. I 00:00:28.297681 executorch:executor_runner.cpp:180] Model executed successfully. ... ```
1 parent 2303947 commit 37265a1

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,14 @@ bool check_tensor_dtype(
2828
case SupportedTensorDtypes::INTB:
2929
return executorch::runtime::tensor_is_integral_type(t, true);
3030
case SupportedTensorDtypes::BOOL_OR_BYTE:
31-
return (
32-
executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
33-
executorch::runtime::tensor_is_type(t, ScalarType::Byte));
31+
return (executorch::runtime::tensor_is_type(
32+
t, ScalarType::Bool, ScalarType::Byte));
3433
case SupportedTensorDtypes::SAME_AS_COMPUTE:
3534
return executorch::runtime::tensor_is_type(t, compute_type);
3635
case SupportedTensorDtypes::SAME_AS_COMMON: {
3736
if (compute_type == ScalarType::Float) {
38-
return (
39-
executorch::runtime::tensor_is_type(t, ScalarType::Float) ||
40-
executorch::runtime::tensor_is_type(t, ScalarType::Half) ||
41-
executorch::runtime::tensor_is_type(t, ScalarType::BFloat16));
37+
return (executorch::runtime::tensor_is_type(
38+
t, ScalarType::Float, ScalarType::Half, ScalarType::BFloat16));
4239
} else {
4340
return executorch::runtime::tensor_is_type(t, compute_type);
4441
}

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,37 @@ inline bool tensor_is_type(
475475
return true;
476476
}
477477

478+
inline bool tensor_is_type(
479+
executorch::aten::Tensor t,
480+
executorch::aten::ScalarType dtype,
481+
executorch::aten::ScalarType dtype2) {
482+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
483+
t.scalar_type() == dtype || t.scalar_type() == dtype2,
484+
"Expected to find %s or %s type, but tensor has type %s",
485+
torch::executor::toString(dtype),
486+
torch::executor::toString(dtype2),
487+
torch::executor::toString(t.scalar_type()));
488+
489+
return true;
490+
}
491+
492+
inline bool tensor_is_type(
493+
executorch::aten::Tensor t,
494+
executorch::aten::ScalarType dtype,
495+
executorch::aten::ScalarType dtype2,
496+
executorch::aten::ScalarType dtype3) {
497+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
498+
t.scalar_type() == dtype || t.scalar_type() == dtype2 ||
499+
t.scalar_type() == dtype3,
500+
"Expected to find %s, %s, or %s type, but tensor has type %s",
501+
torch::executor::toString(dtype),
502+
torch::executor::toString(dtype2),
503+
torch::executor::toString(dtype3),
504+
torch::executor::toString(t.scalar_type()));
505+
506+
return true;
507+
}
508+
478509
inline bool tensor_is_integral_type(
479510
executorch::aten::Tensor t,
480511
bool includeBool = false) {

0 commit comments

Comments
 (0)