Commit 5c676dc
Reapply: "relax tolerances for all unary float ops (#9585)", "Add SupportedTensorDtypes::BOOL (#9584)", new op_mul test (#11206)
Differential Revision: D76754823
Pull Request resolved: #11942
1 parent 91c9ffa commit 5c676dc

4 files changed (+54 / -13 lines)

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::INTB:
       return executorch::runtime::tensor_is_integral_type(t, true);
+    case SupportedTensorDtypes::BOOL:
+      return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (executorch::runtime::tensor_is_type(
           t, ScalarType::Bool, ScalarType::Byte));
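
The new BOOL branch is stricter than the existing BOOL_OR_BYTE branch: it admits only Bool tensors, not Byte. A minimal, self-contained sketch of that distinction follows; the enum values and check_dtype function below are illustrative stand-ins, not the real ExecuTorch declarations.

#include <cassert>

enum class ScalarType { Bool, Byte, Float };
enum class SupportedTensorDtypes { BOOL, BOOL_OR_BYTE };

// Stand-in for the dtype check the diff extends.
bool check_dtype(ScalarType t, SupportedTensorDtypes supported) {
  switch (supported) {
    case SupportedTensorDtypes::BOOL:
      return t == ScalarType::Bool;
    case SupportedTensorDtypes::BOOL_OR_BYTE:
      return t == ScalarType::Bool || t == ScalarType::Byte;
  }
  return false;
}

int main() {
  assert(check_dtype(ScalarType::Bool, SupportedTensorDtypes::BOOL));
  assert(!check_dtype(ScalarType::Byte, SupportedTensorDtypes::BOOL));   // Byte rejected
  assert(check_dtype(ScalarType::Byte, SupportedTensorDtypes::BOOL_OR_BYTE));
}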

kernels/portable/cpu/util/dtype_util.h

Lines changed: 29 additions & 1 deletion
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMPUTE, bool>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
     const Tensor& t) {

@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
+    const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::convert_and_store<bool, CTYPE_COMPUTE>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 store_compute_to_tensor_fn<CTYPE_COMPUTE>
 get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {

@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
   REALHBF16,
   FLOATHBF16,
   INTB,
+  BOOL,
   BOOL_OR_BYTE,
   // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
   SAME_AS_COMPUTE,

@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
       return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:

@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
           t);
     case SupportedTensorDtypes::INTB:
       return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_compute_to_tensor_fn_bool_or_byte<
           CTYPE_COMPUTE,

@@ -318,12 +344,14 @@ bool check_tensor_dtype(
     const ScalarType compute_type);
 
 /// Return the one output type we are willing to emit specialized code
-/// to handle, given a compute type of CTYPE_COMMON and supported
+/// to handle, given a compute type of CTYPE_COMPUTE and supported
 /// output types of out_dtypes.
 template <typename CTYPE_COMPUTE>
 inline constexpr ScalarType specialized_output_scalar_type(
     SupportedTensorDtypes out_dtypes) {
   switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL:
+      return ScalarType::Bool;
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return ScalarType::Bool;
     case SupportedTensorDtypes::REALHBBF16:
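
The two new helpers follow the header's existing load/store dispatch pattern: a kernel does its math in CTYPE_COMPUTE and bridges to the tensor's storage dtype through a pair of function pointers. Below is a minimal sketch of that pattern for the BOOL case, using simplified stand-in signatures for internal::load_and_convert and internal::convert_and_store (the real declarations in dtype_util.h may differ).

#include <cstdio>

template <typename CTYPE_COMPUTE>
using load_fn = CTYPE_COMPUTE (*)(const void*);

template <typename CTYPE_COMPUTE>
using store_fn = void (*)(CTYPE_COMPUTE, void*);

// Stand-ins for the internal:: conversion helpers.
template <typename To, typename From>
To load_and_convert(const void* src) {
  return static_cast<To>(*static_cast<const From*>(src));
}

template <typename To, typename From>
void convert_and_store(From v, void* dst) {
  *static_cast<To*>(dst) = static_cast<To>(v);
}

int main() {
  // For SupportedTensorDtypes::BOOL with a float compute type, the pair is:
  load_fn<float> load = load_and_convert<float, bool>;    // bool storage -> float compute
  store_fn<float> store = convert_and_store<bool, float>; // float compute -> bool storage

  bool in = true;
  float x = load(&in);    // reads 1.0f
  bool out = false;
  store(x * 0.0f, &out);  // writes false
  std::printf("%f %d\n", x, out);
}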

kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h

Lines changed: 8 additions & 12 deletions
@@ -72,20 +72,16 @@ class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest {
 
     auto expected = tf_out.make({1, 6}, expected_vector);
     if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) {
-      double rtol = executorch::runtime::testing::internal::kDefaultRtol;
-      // It appears we need a higher tolerance for at least some ATen
-      // tests, like aten_op_acosh_test.
-      if (get_supported_features()->is_aten) {
-        rtol = 3e-3;
-      }
+      // Raise tolerance because both we and ATen run these
+      // computations at internal float32 precision rather than
+      // float64.
+      double rtol = 3e-3;
       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol);
     } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) {
-      double rtol = executorch::runtime::testing::internal::kDefaultRtol;
-      // It appears we need a higher tolerance for at least some ATen
-      // tests, like aten_op_acosh_test.
-      if (get_supported_features()->is_aten) {
-        rtol = 1e-3;
-      }
+      // Raise tolerance because both we and ATen run these
+      // computations at internal float32 precision rather than
+      // float64.
+      double rtol = 1e-3;
       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol);
     } else {
       EXPECT_TENSOR_CLOSE(out, expected);
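
As a rough sanity check on the relaxed values (an inference, not something stated in the diff): rtol = 3e-3 and 1e-3 sit near one unit of relative rounding error for the narrower storage formats, since bfloat16 carries 8 significand bits (7 stored plus the implicit bit) and float16 carries 11.

#include <cmath>
#include <cstdio>

int main() {
  // Unit roundoff is 2^-p for a binary format with p significand bits.
  double bf16_u = std::ldexp(1.0, -8);   // ~3.9e-3, compare rtol = 3e-3
  double fp16_u = std::ldexp(1.0, -11);  // ~4.9e-4, compare rtol = 1e-3
  std::printf("bf16 %.2g, fp16 %.2g\n", bf16_u, fp16_u);
}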

kernels/test/op_mul_test.cpp

Lines changed: 15 additions & 0 deletions
@@ -746,6 +746,21 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8),
+// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1],
+// dtype=torch.long)) tensor([16])
+TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) {
+  TensorFactory<ScalarType::Char> tf_in;
+  TensorFactory<ScalarType::Long> tf_out;
+
+  Tensor in = tf_in.make({1}, {100});
+  Tensor out = tf_out.zeros({1});
+  Tensor ret = op_mul_out(in, in, out);
+
+  Tensor expected = tf_out.make({1}, {16});
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
 TEST_F(OpMulScalarOutTest, SanityCheck) {
   TensorFactory<ScalarType::Bool> tf_a;
   TensorFactory<ScalarType::Float> tf_out;
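
The expected value of 16 comes straight from the ATen behavior quoted in the test comment: the product is computed at the inputs' common dtype (int8) and only then copied into the long out tensor, so it wraps modulo 256 rather than being widened first. A short standalone sketch of that arithmetic (plain C++, not the kernel code):

#include <cstdint>
#include <cstdio>

int main() {
  int8_t a = 100;
  // 100 * 100 = 10000 = 39 * 256 + 16, so truncating to int8 keeps only
  // the low 8 bits and leaves 16 on typical two's-complement targets.
  int8_t wrapped = static_cast<int8_t>(a * a);
  int64_t out = wrapped;  // only then is the value widened to the long out dtype
  std::printf("%lld\n", static_cast<long long>(out));  // prints 16
}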
