Skip to content

Commit 157bf91

Browse files
Use non-fatal input checks for view operator (#9264)
### Summary Fixes #9130 Use non-fatal input checks for view operator, just like https://github.com/pytorch/executorch/pull/2115/files, since there are existing cases that an internal app crashed due to a check failure. Test is updated as well for `RegisterPrimOpsTest`. ### Test plan Pass updated test: ``` mkdir cmake-out && cd cmake-out && cmake -DEXECUTORCH_BUILD_TESTS=ON .. && cmake --build . --target kernels_prim_ops_test -j9 && ctest -R kernels_prim_ops_test ``` Pass existing CI/CD. cc @GregoryComer
1 parent 23fe285 commit 157bf91

File tree

3 files changed

+55
-45
lines changed

3 files changed

+55
-45
lines changed

kernels/prim_ops/et_view.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,38 @@ void et_view(KernelRuntimeContext& context, EValue** stack) {
7272
auto size = (*stack[1]).toIntList();
7373
auto out = (*stack[2]).toTensor();
7474

75-
ET_CHECK(tensors_have_same_dtype(self, out));
75+
ET_KERNEL_CHECK(
76+
context, tensors_have_same_dtype(self, out), InvalidArgument, );
7677

7778
// Compute output size
7879
SizesType expected_output_size[kTensorDimensionLimit];
79-
ET_CHECK(get_view_target_size(self, size, out.dim(), expected_output_size));
80+
ET_KERNEL_CHECK(
81+
context,
82+
get_view_target_size(self, size, out.dim(), expected_output_size),
83+
InvalidArgument, );
8084

8185
// Resize for dynamic shape
82-
ET_CHECK_MSG(
86+
ET_KERNEL_CHECK_MSG(
87+
context,
8388
resize_tensor(
8489
out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
8590
Error::Ok,
91+
Internal,
92+
,
8693
"Failed to resize output tensor.");
8794

8895
// Do some checks
89-
ET_CHECK(self.numel() == out.numel());
96+
ET_KERNEL_CHECK(context, self.numel() == out.numel(), InvalidArgument, );
9097

9198
// Update data ptr
92-
ET_CHECK_MSG(
99+
ET_KERNEL_CHECK_MSG(
100+
context,
93101
internal::set_tensor_data(
94102
out,
95103
/*buffer=*/self.mutable_data_ptr(),
96104
/*buffer_size=*/out.nbytes()) == Error::Ok,
105+
Internal,
106+
,
97107
"Failed to set data_ptr for out to self.");
98108
}
99109

kernels/prim_ops/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ runtime.cxx_test(
2424
],
2525
deps = [
2626
"//executorch/kernels/prim_ops:prim_ops_registry", # @manual
27+
"//executorch/kernels/test:test_util", # @manual
2728
"//executorch/runtime/core:evalue", # @manual
2829
"//executorch/runtime/core/exec_aten:lib", # @manual
2930
"//executorch/runtime/core/exec_aten/testing_util:tensor_util", # @manual

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <gtest/gtest.h>
1010

11+
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/runtime/core/evalue.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1314
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -16,7 +17,6 @@
1617
#include <executorch/runtime/kernel/kernel_runtime_context.h>
1718
#include <executorch/runtime/kernel/operator_registry.h>
1819
#include <executorch/runtime/platform/runtime.h>
19-
#include <executorch/test/utils/DeathTest.h>
2020
#include <cstdint>
2121
#include <cstdio>
2222

@@ -27,12 +27,10 @@ using torch::executor::resize_tensor;
2727
namespace torch {
2828
namespace executor {
2929

30-
class RegisterPrimOpsTest : public ::testing::Test {
30+
class RegisterPrimOpsTest : public OperatorTest {
3131
protected:
32-
KernelRuntimeContext context;
3332
void SetUp() override {
34-
torch::executor::runtime_init();
35-
context = KernelRuntimeContext();
33+
context_ = KernelRuntimeContext();
3634
}
3735
};
3836

@@ -57,7 +55,7 @@ TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
5755
stack[i] = &values[i];
5856
}
5957

60-
getOpsFn("aten::sym_size.int")(context, stack);
58+
getOpsFn("aten::sym_size.int")(context_, stack);
6159

6260
int64_t expected = 5;
6361
EXPECT_EQ(stack[2]->toInt(), expected);
@@ -77,7 +75,7 @@ TEST_F(RegisterPrimOpsTest, SymNumelReturnsCorrectValue) {
7775
stack[i] = &values[i];
7876
}
7977

80-
getOpsFn("aten::sym_numel")(context, stack);
78+
getOpsFn("aten::sym_numel")(context_, stack);
8179

8280
int64_t expected = 15;
8381
EXPECT_EQ(stack[1]->toInt(), expected);
@@ -97,28 +95,28 @@ TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
9795
stack[i] = &values[i];
9896
}
9997

100-
getOpsFn("executorch_prim::add.Scalar")(context, stack);
98+
getOpsFn("executorch_prim::add.Scalar")(context_, stack);
10199
EXPECT_EQ(stack[2]->toInt(), 7);
102100

103-
getOpsFn("executorch_prim::sub.Scalar")(context, stack);
101+
getOpsFn("executorch_prim::sub.Scalar")(context_, stack);
104102
EXPECT_EQ(stack[2]->toInt(), -1);
105103

106-
getOpsFn("executorch_prim::mul.Scalar")(context, stack);
104+
getOpsFn("executorch_prim::mul.Scalar")(context_, stack);
107105
EXPECT_EQ(stack[2]->toInt(), 12);
108106

109-
getOpsFn("executorch_prim::floordiv.Scalar")(context, stack);
107+
getOpsFn("executorch_prim::floordiv.Scalar")(context_, stack);
110108
EXPECT_EQ(stack[2]->toInt(), 0);
111109

112-
getOpsFn("executorch_prim::truediv.Scalar")(context, stack);
110+
getOpsFn("executorch_prim::truediv.Scalar")(context_, stack);
113111
EXPECT_FLOAT_EQ(stack[2]->toDouble(), 0.75);
114112

115-
getOpsFn("executorch_prim::mod.int")(context, stack);
113+
getOpsFn("executorch_prim::mod.int")(context_, stack);
116114
EXPECT_EQ(stack[2]->toInt(), 3);
117115

118-
getOpsFn("executorch_prim::mod.Scalar")(context, stack);
116+
getOpsFn("executorch_prim::mod.Scalar")(context_, stack);
119117
EXPECT_EQ(stack[2]->toInt(), 3);
120118

121-
getOpsFn("executorch_prim::sym_float.Scalar")(context, stack);
119+
getOpsFn("executorch_prim::sym_float.Scalar")(context_, stack);
122120
EXPECT_FLOAT_EQ(stack[1]->toDouble(), 3.0);
123121
}
124122

@@ -155,7 +153,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
155153
stack[2] = &values[2];
156154

157155
// Simple test to copy to index 0.
158-
getOpsFn("executorch_prim::et_copy_index.tensor")(context, stack);
156+
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack);
159157

160158
EXPECT_EQ(copy_to.sizes()[0], 1);
161159
EXPECT_EQ(copy_to.sizes()[1], 2);
@@ -164,7 +162,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
164162
values[1] = tf.make({2}, {5, 6});
165163
values[2] = EValue((int64_t)1);
166164
// Copy to the next index, 1.
167-
getOpsFn("executorch_prim::et_copy_index.tensor")(context, stack);
165+
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack);
168166

169167
EXPECT_EQ(copy_to.sizes()[0], 2);
170168
EXPECT_EQ(copy_to.sizes()[1], 2);
@@ -193,7 +191,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) {
193191
// copy_to.sizes[1:] and to_copy.sizes[:] don't match each other
194192
// which is a pre-requisite for this operator.
195193
ET_EXPECT_DEATH(
196-
getOpsFn("executorch_prim::et_copy_index.tensor")(context, stack), "");
194+
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), "");
197195
}
198196

199197
TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
@@ -217,7 +215,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
217215
stack[2] = &values[2];
218216

219217
// Copy and replace at index 1.
220-
getOpsFn("executorch_prim::et_copy_index.tensor")(context, stack);
218+
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack);
221219
EXPECT_EQ(copy_to.sizes()[0], 2);
222220
EXPECT_EQ(copy_to.sizes()[1], 2);
223221
EXPECT_TENSOR_EQ(copy_to, tf.make({2, 2}, {1, 2, 5, 6}));
@@ -228,7 +226,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
228226
index = 2;
229227
values[2] = EValue(index);
230228
ET_EXPECT_DEATH(
231-
getOpsFn("executorch_prim::et_copy_index.tensor")(context, stack), "");
229+
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), "");
232230
#endif
233231
}
234232

@@ -246,19 +244,19 @@ TEST_F(RegisterPrimOpsTest, TestBooleanOps) {
246244
stack[i] = &values[i];
247245
}
248246

249-
getOpsFn("executorch_prim::ge.Scalar")(context, stack);
247+
getOpsFn("executorch_prim::ge.Scalar")(context_, stack);
250248
EXPECT_EQ(stack[2]->toBool(), false);
251249

252-
getOpsFn("executorch_prim::gt.Scalar")(context, stack);
250+
getOpsFn("executorch_prim::gt.Scalar")(context_, stack);
253251
EXPECT_EQ(stack[2]->toBool(), false);
254252

255-
getOpsFn("executorch_prim::le.Scalar")(context, stack);
253+
getOpsFn("executorch_prim::le.Scalar")(context_, stack);
256254
EXPECT_EQ(stack[2]->toBool(), true);
257255

258-
getOpsFn("executorch_prim::lt.Scalar")(context, stack);
256+
getOpsFn("executorch_prim::lt.Scalar")(context_, stack);
259257
EXPECT_EQ(stack[2]->toBool(), true);
260258

261-
getOpsFn("executorch_prim::eq.Scalar")(context, stack);
259+
getOpsFn("executorch_prim::eq.Scalar")(context_, stack);
262260
EXPECT_EQ(stack[2]->toBool(), false);
263261
}
264262

@@ -277,7 +275,7 @@ TEST_F(RegisterPrimOpsTest, LocalScalarDenseReturnsCorrectValue) {
277275
stack[i] = &values[i];
278276
}
279277

280-
getOpsFn("aten::_local_scalar_dense")(context, stack);
278+
getOpsFn("aten::_local_scalar_dense")(context_, stack);
281279

282280
int64_t expected = 1;
283281
EXPECT_EQ(stack[1]->toInt(), expected);
@@ -295,7 +293,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
295293
stack[i] = &values[i];
296294
}
297295

298-
getOpsFn("executorch_prim::neg.Scalar")(context, stack);
296+
getOpsFn("executorch_prim::neg.Scalar")(context_, stack);
299297

300298
EXPECT_EQ(stack[1]->toDouble(), -5.0f);
301299

@@ -305,7 +303,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
305303
values[0] = EValue(a);
306304
values[1] = EValue(b);
307305

308-
getOpsFn("executorch_prim::neg.Scalar")(context, stack);
306+
getOpsFn("executorch_prim::neg.Scalar")(context_, stack);
309307

310308
EXPECT_EQ(stack[1]->toInt(), -5l);
311309
}
@@ -327,7 +325,7 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
327325
}
328326

329327
// Try to negate a tensor, which should cause a runtime error.
330-
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context, stack), "");
328+
ET_EXPECT_DEATH(getOpsFn("executorch_prim::neg.Scalar")(context_, stack), "");
331329
}
332330

333331
TEST_F(RegisterPrimOpsTest, TestETView) {
@@ -410,9 +408,9 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
410408

411409
// Bad stacks expect death
412410
for (int i = 0; i < N_BAD_STACKS; i++) {
413-
ET_EXPECT_DEATH(
414-
getOpsFn("executorch_prim::et_view.default")(context, bad_stacks[i]),
415-
"");
411+
ET_EXPECT_KERNEL_FAILURE(
412+
context_,
413+
getOpsFn("executorch_prim::et_view.default")(context_, bad_stacks[i]));
416414
}
417415

418416
constexpr int N_GOOD_STACKS = N_GOOD_OUTS;
@@ -422,7 +420,7 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
422420

423421
// Good outs expect no death and correct output
424422
for (int i = 0; i < N_GOOD_STACKS; i++) {
425-
getOpsFn("executorch_prim::et_view.default")(context, good_out_stacks[i]);
423+
getOpsFn("executorch_prim::et_view.default")(context_, good_out_stacks[i]);
426424
EXPECT_TENSOR_EQ(good_outs[i], tf.make({1, 3, 2}, {1, 2, 3, 4, 5, 6}));
427425
EXPECT_EQ(good_outs[i].const_data_ptr(), self.const_data_ptr());
428426
}
@@ -456,7 +454,7 @@ TEST_F(RegisterPrimOpsTest, TestETViewDynamic) {
456454

457455
EValue* stack[3] = {&self_evalue, &size_int_list_evalue, &out_evalue};
458456

459-
getOpsFn("executorch_prim::et_view.default")(context, stack);
457+
getOpsFn("executorch_prim::et_view.default")(context_, stack);
460458

461459
EXPECT_TENSOR_EQ(out, tf.make({1, 3, 1}, {1, 2, 3}));
462460
EXPECT_EQ(out.const_data_ptr(), self.const_data_ptr());
@@ -493,14 +491,15 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
493491

494492
// good size test
495493
EValue* stack[3] = {&self_evalue, &size_int_list_evalue, &out_evalue};
496-
getOpsFn("executorch_prim::et_view.default")(context, stack);
494+
getOpsFn("executorch_prim::et_view.default")(context_, stack);
497495
EXPECT_TENSOR_EQ(out, tf.make({3, 1, 0}, {}));
498496
EXPECT_EQ(out.const_data_ptr(), self.const_data_ptr());
499497

500498
// bad size test
501499
EValue* bad_stack[3] = {&self_evalue, &bad_size_int_list_evalue, &out_evalue};
502-
ET_EXPECT_DEATH(
503-
getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
500+
ET_EXPECT_KERNEL_FAILURE(
501+
context_,
502+
getOpsFn("executorch_prim::et_view.default")(context_, bad_stack));
504503
}
505504

506505
TEST_F(RegisterPrimOpsTest, TestCeil) {
@@ -518,7 +517,7 @@ TEST_F(RegisterPrimOpsTest, TestCeil) {
518517
stack[j] = &values[j];
519518
}
520519

521-
getOpsFn("executorch_prim::ceil.Scalar")(context, stack);
520+
getOpsFn("executorch_prim::ceil.Scalar")(context_, stack);
522521
EXPECT_EQ(stack[1]->toInt(), expected[i]);
523522
}
524523
}
@@ -539,7 +538,7 @@ TEST_F(RegisterPrimOpsTest, TestRound) {
539538
stack[j] = &values[j];
540539
}
541540

542-
getOpsFn("executorch_prim::round.Scalar")(context, stack);
541+
getOpsFn("executorch_prim::round.Scalar")(context_, stack);
543542
EXPECT_EQ(stack[1]->toInt(), expected[i]);
544543
}
545544
}
@@ -559,7 +558,7 @@ TEST_F(RegisterPrimOpsTest, TestTrunc) {
559558
stack[j] = &values[j];
560559
}
561560

562-
getOpsFn("executorch_prim::trunc.Scalar")(context, stack);
561+
getOpsFn("executorch_prim::trunc.Scalar")(context_, stack);
563562
EXPECT_EQ(stack[1]->toInt(), expected[i]);
564563
}
565564
}

0 commit comments

Comments
 (0)