|
14 | 14 | #include <cmath>
|
15 | 15 | #include <cstddef> // size_t
|
16 | 16 | #include <limits>
|
17 |
| -#include <sstream> |
18 |
| -#include <vector> |
19 | 17 |
|
20 | 18 | #include <executorch/runtime/core/array_ref.h>
|
21 | 19 | #include <executorch/runtime/core/error.h>
|
@@ -488,26 +486,33 @@ inline bool tensor_is_type(
|
488 | 486 |
|
489 | 487 | inline bool tensor_is_type(
|
490 | 488 | executorch::aten::Tensor t,
|
491 |
| - const std::vector<executorch::aten::ScalarType>& dtypes) { |
492 |
| - if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) != |
493 |
| - dtypes.end()) { |
494 |
| - return true; |
495 |
| - } |
| 489 | + executorch::aten::ScalarType dtype, |
| 490 | + executorch::aten::ScalarType dtype2) { |
| 491 | + ET_LOG_MSG_AND_RETURN_IF_FALSE( |
| 492 | + t.scalar_type() == dtype || t.scalar_type() == dtype2, |
| 493 | + "Expected to find %s or %s type, but tensor has type %s", |
| 494 | + torch::executor::toString(dtype), |
| 495 | + torch::executor::toString(dtype2), |
| 496 | + torch::executor::toString(t.scalar_type())); |
496 | 497 |
|
497 |
| - std::stringstream dtype_ss; |
498 |
| - for (size_t i = 0; i < dtypes.size(); i++) { |
499 |
| - if (i != 0) { |
500 |
| - dtype_ss << ", "; |
501 |
| - } |
502 |
| - dtype_ss << torch::executor::toString(dtypes[i]); |
503 |
| - } |
| 498 | + return true; |
| 499 | +} |
504 | 500 |
|
| 501 | +inline bool tensor_is_type( |
| 502 | + executorch::aten::Tensor t, |
| 503 | + executorch::aten::ScalarType dtype, |
| 504 | + executorch::aten::ScalarType dtype2, |
| 505 | + executorch::aten::ScalarType dtype3) { |
505 | 506 | ET_LOG_MSG_AND_RETURN_IF_FALSE(
|
506 |
| - false, |
507 |
| - "Expected to find one of %s types, but tensor has type %s", |
508 |
| - dtype_ss.str().c_str(), |
| 507 | + t.scalar_type() == dtype || t.scalar_type() == dtype2 || |
| 508 | + t.scalar_type() == dtype3, |
| 509 | + "Expected to find %s, %s, or %s type, but tensor has type %s", |
| 510 | + torch::executor::toString(dtype), |
| 511 | + torch::executor::toString(dtype2), |
| 512 | + torch::executor::toString(dtype3), |
509 | 513 | torch::executor::toString(t.scalar_type()));
|
510 |
| - return false; |
| 514 | + |
| 515 | + return true; |
511 | 516 | }
|
512 | 517 |
|
513 | 518 | inline bool tensor_is_integral_type(
|
|
0 commit comments