Skip to content

Commit 4374afe

Browse files
[ET][Portable] Test bad alpha values: op_add (#12091)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12030 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/128/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/128/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/128/orig @diff-train-skip-merge Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent f888bdf commit 4374afe

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ Tensor& add_out(
5151
static constexpr const char op_name[] = "add.out";
5252

5353
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
54+
CTYPE_COMPUTE val_alpha;
55+
ET_KERNEL_CHECK(
56+
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
5557
utils::apply_bitensor_elementwise_fn<
5658
CTYPE_COMPUTE,
5759
op_name,
@@ -103,7 +105,9 @@ Tensor& add_scalar_out(
103105

104106
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
105107
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106-
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
108+
CTYPE_COMPUTE val_alpha;
109+
ET_KERNEL_CHECK(
110+
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
107111
auto val_alpha_times_b = val_alpha * val_b;
108112
utils::apply_unitensor_elementwise_fn<
109113
CTYPE_COMPUTE,

kernels/test/op_add_test.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -15,8 +16,6 @@
1516

1617
#include <gtest/gtest.h>
1718

18-
#include <iostream>
19-
2019
using namespace ::testing;
2120
using executorch::aten::Scalar;
2221
using executorch::aten::ScalarType;
@@ -231,6 +230,27 @@ class OpAddOutKernelTest : public OperatorTest {
231230
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
232231
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
233232
}
233+
234+
template <ScalarType DTYPE>
235+
void expect_bad_alpha_value_dies(const Scalar& bad_value) {
236+
TensorFactory<DTYPE> tf;
237+
Tensor a = tf.ones({2, 2});
238+
Tensor b = tf.ones({2, 2});
239+
Tensor out = tf.zeros({2, 2});
240+
241+
ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, bad_value, out));
242+
}
243+
244+
// The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
245+
// test cases requires a method called expect_bad_scalar_value_dies. However,
246+
// for add operation, these checks only apply to the alpha argument.
247+
// We are being explicit about this by naming the above function
248+
// expect_bad_alpha_value_dies, and creating this wrapper in order to use the
249+
// macro.
250+
template <ScalarType DTYPE>
251+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
252+
expect_bad_alpha_value_dies<DTYPE>(bad_value);
253+
}
234254
};
235255

236256
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -242,6 +262,27 @@ class OpAddScalarOutKernelTest : public OperatorTest {
242262
Tensor& out) {
243263
return torch::executor::aten::add_outf(context_, self, other, alpha, out);
244264
}
265+
266+
template <ScalarType DTYPE>
267+
void expect_bad_alpha_value_dies(const Scalar& bad_value) {
268+
TensorFactory<DTYPE> tf;
269+
Tensor a = tf.ones({2, 2});
270+
Scalar b = 1;
271+
Tensor out = tf.zeros({2, 2});
272+
273+
ET_EXPECT_KERNEL_FAILURE(context_, op_add_scalar_out(a, b, bad_value, out));
274+
}
275+
276+
// The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
277+
// test cases requires a method called expect_bad_scalar_value_dies. However,
278+
// for the add operation, these checks only apply to the alpha argument.
279+
// We are being explicit about this by naming the above function
280+
// expect_bad_alpha_value_dies, and creating this wrapper in order to use the
281+
// macro.
282+
template <ScalarType DTYPE>
283+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
284+
expect_bad_alpha_value_dies<DTYPE>(bad_value);
285+
}
245286
};
246287

247288
/**
@@ -794,3 +835,26 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
794835
op_add_scalar_out(self, other, alpha, out);
795836
EXPECT_TENSOR_CLOSE(out, out_expected);
796837
}
838+
839+
TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
840+
// Cannot be represented by a uint8_t.
841+
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
842+
}
843+
844+
TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
845+
// Cannot be represented by a uint32_t.
846+
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
847+
}
848+
849+
TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
850+
// Cannot be represented by a uint8_t.
851+
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
852+
}
853+
854+
TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
855+
// Cannot be represented by a uint32_t.
856+
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
857+
}
858+
859+
GENERATE_SCALAR_OVERFLOW_TESTS(OpAddOutKernelTest)
860+
GENERATE_SCALAR_OVERFLOW_TESTS(OpAddScalarOutKernelTest)

0 commit comments

Comments
 (0)