7
7
*/
8
8
9
9
#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10
+ #include < executorch/kernels/test/ScalarOverflowTestMacros.h>
10
11
#include < executorch/kernels/test/TestUtil.h>
11
12
#include < executorch/kernels/test/supported_features.h>
12
13
#include < executorch/runtime/core/exec_aten/exec_aten.h>
15
16
16
17
#include < gtest/gtest.h>
17
18
18
- #include < iostream>
19
-
20
19
using namespace ::testing;
21
20
using executorch::aten::Scalar;
22
21
using executorch::aten::ScalarType;
@@ -231,6 +230,27 @@ class OpAddOutKernelTest : public OperatorTest {
231
230
EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
232
231
EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
233
232
}
233
+
234
+ template <ScalarType DTYPE>
235
+ void expect_bad_alpha_value_dies (const Scalar& bad_value) {
236
+ TensorFactory<DTYPE> tf;
237
+ Tensor a = tf.ones ({2 , 2 });
238
+ Tensor b = tf.ones ({2 , 2 });
239
+ Tensor out = tf.zeros ({2 , 2 });
240
+
241
+ ET_EXPECT_KERNEL_FAILURE (context_, op_add_out (a, b, bad_value, out));
242
+ }
243
+
244
+ // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
245
+ // test cases requires a method called expect_bad_scalar_value_dies. However,
246
+ // for add operation, these checks only apply to the alpha argument.
247
+ // We are being explicit about this by naming the above function
248
+ // expect_bad_alpha_value_dies, and creating this wrapper in order to use the
249
+ // macro.
250
+ template <ScalarType DTYPE>
251
+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
252
+ expect_bad_alpha_value_dies<DTYPE>(bad_value);
253
+ }
234
254
};
235
255
236
256
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -242,6 +262,27 @@ class OpAddScalarOutKernelTest : public OperatorTest {
242
262
Tensor& out) {
243
263
return torch::executor::aten::add_outf (context_, self, other, alpha, out);
244
264
}
265
+
266
+ template <ScalarType DTYPE>
267
+ void expect_bad_alpha_value_dies (const Scalar& bad_value) {
268
+ TensorFactory<DTYPE> tf;
269
+ Tensor a = tf.ones ({2 , 2 });
270
+ Scalar b = 1 ;
271
+ Tensor out = tf.zeros ({2 , 2 });
272
+
273
+ ET_EXPECT_KERNEL_FAILURE (context_, op_add_scalar_out (a, b, bad_value, out));
274
+ }
275
+
276
+ // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
277
+ // test cases requires a method called expect_bad_scalar_value_dies. However,
278
+ // for the add operation, these checks only apply to the alpha argument.
279
+ // We are being explicit about this by naming the above function
280
+ // expect_bad_alpha_value_dies, and creating this wrapper in order to use the
281
+ // macro.
282
+ template <ScalarType DTYPE>
283
+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
284
+ expect_bad_alpha_value_dies<DTYPE>(bad_value);
285
+ }
245
286
};
246
287
247
288
/* *
@@ -794,3 +835,26 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
794
835
op_add_scalar_out (self, other, alpha, out);
795
836
EXPECT_TENSOR_CLOSE (out, out_expected);
796
837
}
838
+
839
+ TEST_F (OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
840
+ // Cannot be represented by a uint8_t.
841
+ expect_bad_alpha_value_dies<ScalarType::Byte>(2.2 );
842
+ }
843
+
844
+ TEST_F (OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
845
+ // Cannot be represented by a uint32_t.
846
+ expect_bad_alpha_value_dies<ScalarType::Int>(2.2 );
847
+ }
848
+
849
+ TEST_F (OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
850
+ // Cannot be represented by a uint8_t.
851
+ expect_bad_alpha_value_dies<ScalarType::Byte>(2.2 );
852
+ }
853
+
854
+ TEST_F (OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
855
+ // Cannot be represented by a uint32_t.
856
+ expect_bad_alpha_value_dies<ScalarType::Int>(2.2 );
857
+ }
858
+
859
+ GENERATE_SCALAR_OVERFLOW_TESTS (OpAddOutKernelTest)
860
+ GENERATE_SCALAR_OVERFLOW_TESTS(OpAddScalarOutKernelTest)
0 commit comments