8
8
9
9
#include < gtest/gtest.h>
10
10
11
+ #include < executorch/kernels/test/TestUtil.h>
11
12
#include < executorch/runtime/core/evalue.h>
12
13
#include < executorch/runtime/core/exec_aten/exec_aten.h>
13
14
#include < executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
16
17
#include < executorch/runtime/kernel/kernel_runtime_context.h>
17
18
#include < executorch/runtime/kernel/operator_registry.h>
18
19
#include < executorch/runtime/platform/runtime.h>
19
- #include < executorch/test/utils/DeathTest.h>
20
20
#include < cstdint>
21
21
#include < cstdio>
22
22
@@ -27,12 +27,10 @@ using torch::executor::resize_tensor;
27
27
namespace torch {
28
28
namespace executor {
29
29
30
- class RegisterPrimOpsTest : public ::testing::Test {
30
+ class RegisterPrimOpsTest : public OperatorTest {
31
31
protected:
32
- KernelRuntimeContext context;
33
32
void SetUp () override {
34
- torch::executor::runtime_init ();
35
- context = KernelRuntimeContext ();
33
+ context_ = KernelRuntimeContext ();
36
34
}
37
35
};
38
36
@@ -57,7 +55,7 @@ TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
57
55
stack[i] = &values[i];
58
56
}
59
57
60
- getOpsFn (" aten::sym_size.int" )(context , stack);
58
+ getOpsFn (" aten::sym_size.int" )(context_ , stack);
61
59
62
60
int64_t expected = 5 ;
63
61
EXPECT_EQ (stack[2 ]->toInt (), expected);
@@ -77,7 +75,7 @@ TEST_F(RegisterPrimOpsTest, SymNumelReturnsCorrectValue) {
77
75
stack[i] = &values[i];
78
76
}
79
77
80
- getOpsFn (" aten::sym_numel" )(context , stack);
78
+ getOpsFn (" aten::sym_numel" )(context_ , stack);
81
79
82
80
int64_t expected = 15 ;
83
81
EXPECT_EQ (stack[1 ]->toInt (), expected);
@@ -97,28 +95,28 @@ TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
97
95
stack[i] = &values[i];
98
96
}
99
97
100
- getOpsFn (" executorch_prim::add.Scalar" )(context , stack);
98
+ getOpsFn (" executorch_prim::add.Scalar" )(context_ , stack);
101
99
EXPECT_EQ (stack[2 ]->toInt (), 7 );
102
100
103
- getOpsFn (" executorch_prim::sub.Scalar" )(context , stack);
101
+ getOpsFn (" executorch_prim::sub.Scalar" )(context_ , stack);
104
102
EXPECT_EQ (stack[2 ]->toInt (), -1 );
105
103
106
- getOpsFn (" executorch_prim::mul.Scalar" )(context , stack);
104
+ getOpsFn (" executorch_prim::mul.Scalar" )(context_ , stack);
107
105
EXPECT_EQ (stack[2 ]->toInt (), 12 );
108
106
109
- getOpsFn (" executorch_prim::floordiv.Scalar" )(context , stack);
107
+ getOpsFn (" executorch_prim::floordiv.Scalar" )(context_ , stack);
110
108
EXPECT_EQ (stack[2 ]->toInt (), 0 );
111
109
112
- getOpsFn (" executorch_prim::truediv.Scalar" )(context , stack);
110
+ getOpsFn (" executorch_prim::truediv.Scalar" )(context_ , stack);
113
111
EXPECT_FLOAT_EQ (stack[2 ]->toDouble (), 0.75 );
114
112
115
- getOpsFn (" executorch_prim::mod.int" )(context , stack);
113
+ getOpsFn (" executorch_prim::mod.int" )(context_ , stack);
116
114
EXPECT_EQ (stack[2 ]->toInt (), 3 );
117
115
118
- getOpsFn (" executorch_prim::mod.Scalar" )(context , stack);
116
+ getOpsFn (" executorch_prim::mod.Scalar" )(context_ , stack);
119
117
EXPECT_EQ (stack[2 ]->toInt (), 3 );
120
118
121
- getOpsFn (" executorch_prim::sym_float.Scalar" )(context , stack);
119
+ getOpsFn (" executorch_prim::sym_float.Scalar" )(context_ , stack);
122
120
EXPECT_FLOAT_EQ (stack[1 ]->toDouble (), 3.0 );
123
121
}
124
122
@@ -155,7 +153,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
155
153
stack[2 ] = &values[2 ];
156
154
157
155
// Simple test to copy to index 0.
158
- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
156
+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
159
157
160
158
EXPECT_EQ (copy_to.sizes ()[0 ], 1 );
161
159
EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
@@ -164,7 +162,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndex) {
164
162
values[1 ] = tf.make ({2 }, {5 , 6 });
165
163
values[2 ] = EValue ((int64_t )1 );
166
164
// Copy to the next index, 1.
167
- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
165
+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
168
166
169
167
EXPECT_EQ (copy_to.sizes ()[0 ], 2 );
170
168
EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
@@ -193,7 +191,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) {
193
191
// copy_to.sizes[1:] and to_copy.sizes[:] don't match each other
194
192
// which is a pre-requisite for this operator.
195
193
ET_EXPECT_DEATH (
196
- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack), " " );
194
+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack), " " );
197
195
}
198
196
199
197
TEST_F (RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
@@ -217,7 +215,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
217
215
stack[2 ] = &values[2 ];
218
216
219
217
// Copy and replace at index 1.
220
- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack);
218
+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack);
221
219
EXPECT_EQ (copy_to.sizes ()[0 ], 2 );
222
220
EXPECT_EQ (copy_to.sizes ()[1 ], 2 );
223
221
EXPECT_TENSOR_EQ (copy_to, tf.make ({2 , 2 }, {1 , 2 , 5 , 6 }));
@@ -228,7 +226,7 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
228
226
index = 2 ;
229
227
values[2 ] = EValue (index);
230
228
ET_EXPECT_DEATH (
231
- getOpsFn (" executorch_prim::et_copy_index.tensor" )(context , stack), " " );
229
+ getOpsFn (" executorch_prim::et_copy_index.tensor" )(context_ , stack), " " );
232
230
#endif
233
231
}
234
232
@@ -246,19 +244,19 @@ TEST_F(RegisterPrimOpsTest, TestBooleanOps) {
246
244
stack[i] = &values[i];
247
245
}
248
246
249
- getOpsFn (" executorch_prim::ge.Scalar" )(context , stack);
247
+ getOpsFn (" executorch_prim::ge.Scalar" )(context_ , stack);
250
248
EXPECT_EQ (stack[2 ]->toBool (), false );
251
249
252
- getOpsFn (" executorch_prim::gt.Scalar" )(context , stack);
250
+ getOpsFn (" executorch_prim::gt.Scalar" )(context_ , stack);
253
251
EXPECT_EQ (stack[2 ]->toBool (), false );
254
252
255
- getOpsFn (" executorch_prim::le.Scalar" )(context , stack);
253
+ getOpsFn (" executorch_prim::le.Scalar" )(context_ , stack);
256
254
EXPECT_EQ (stack[2 ]->toBool (), true );
257
255
258
- getOpsFn (" executorch_prim::lt.Scalar" )(context , stack);
256
+ getOpsFn (" executorch_prim::lt.Scalar" )(context_ , stack);
259
257
EXPECT_EQ (stack[2 ]->toBool (), true );
260
258
261
- getOpsFn (" executorch_prim::eq.Scalar" )(context , stack);
259
+ getOpsFn (" executorch_prim::eq.Scalar" )(context_ , stack);
262
260
EXPECT_EQ (stack[2 ]->toBool (), false );
263
261
}
264
262
@@ -277,7 +275,7 @@ TEST_F(RegisterPrimOpsTest, LocalScalarDenseReturnsCorrectValue) {
277
275
stack[i] = &values[i];
278
276
}
279
277
280
- getOpsFn (" aten::_local_scalar_dense" )(context , stack);
278
+ getOpsFn (" aten::_local_scalar_dense" )(context_ , stack);
281
279
282
280
int64_t expected = 1 ;
283
281
EXPECT_EQ (stack[1 ]->toInt (), expected);
@@ -295,7 +293,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
295
293
stack[i] = &values[i];
296
294
}
297
295
298
- getOpsFn (" executorch_prim::neg.Scalar" )(context , stack);
296
+ getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack);
299
297
300
298
EXPECT_EQ (stack[1 ]->toDouble (), -5 .0f );
301
299
@@ -305,7 +303,7 @@ TEST_F(RegisterPrimOpsTest, NegScalarReturnsCorrectValue) {
305
303
values[0 ] = EValue (a);
306
304
values[1 ] = EValue (b);
307
305
308
- getOpsFn (" executorch_prim::neg.Scalar" )(context , stack);
306
+ getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack);
309
307
310
308
EXPECT_EQ (stack[1 ]->toInt (), -5l );
311
309
}
@@ -327,7 +325,7 @@ TEST_F(RegisterPrimOpsTest, TestNegScalarWithTensorDies) {
327
325
}
328
326
329
327
// Try to negate a tensor, which should cause a runtime error.
330
- ET_EXPECT_DEATH (getOpsFn (" executorch_prim::neg.Scalar" )(context , stack), " " );
328
+ ET_EXPECT_DEATH (getOpsFn (" executorch_prim::neg.Scalar" )(context_ , stack), " " );
331
329
}
332
330
333
331
TEST_F (RegisterPrimOpsTest, TestETView) {
@@ -410,9 +408,9 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
410
408
411
409
// Bad stacks expect death
412
410
for (int i = 0 ; i < N_BAD_STACKS; i++) {
413
- ET_EXPECT_DEATH (
414
- getOpsFn ( " executorch_prim::et_view.default " )(context, bad_stacks[i]) ,
415
- " " );
411
+ ET_EXPECT_KERNEL_FAILURE (
412
+ context_ ,
413
+ getOpsFn ( " executorch_prim::et_view.default " )(context_, bad_stacks[i]) );
416
414
}
417
415
418
416
constexpr int N_GOOD_STACKS = N_GOOD_OUTS;
@@ -422,7 +420,7 @@ TEST_F(RegisterPrimOpsTest, TestETView) {
422
420
423
421
// Good outs expect no death and correct output
424
422
for (int i = 0 ; i < N_GOOD_STACKS; i++) {
425
- getOpsFn (" executorch_prim::et_view.default" )(context , good_out_stacks[i]);
423
+ getOpsFn (" executorch_prim::et_view.default" )(context_ , good_out_stacks[i]);
426
424
EXPECT_TENSOR_EQ (good_outs[i], tf.make ({1 , 3 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 }));
427
425
EXPECT_EQ (good_outs[i].const_data_ptr (), self.const_data_ptr ());
428
426
}
@@ -456,7 +454,7 @@ TEST_F(RegisterPrimOpsTest, TestETViewDynamic) {
456
454
457
455
EValue* stack[3 ] = {&self_evalue, &size_int_list_evalue, &out_evalue};
458
456
459
- getOpsFn (" executorch_prim::et_view.default" )(context , stack);
457
+ getOpsFn (" executorch_prim::et_view.default" )(context_ , stack);
460
458
461
459
EXPECT_TENSOR_EQ (out, tf.make ({1 , 3 , 1 }, {1 , 2 , 3 }));
462
460
EXPECT_EQ (out.const_data_ptr (), self.const_data_ptr ());
@@ -493,14 +491,15 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
493
491
494
492
// good size test
495
493
EValue* stack[3 ] = {&self_evalue, &size_int_list_evalue, &out_evalue};
496
- getOpsFn (" executorch_prim::et_view.default" )(context , stack);
494
+ getOpsFn (" executorch_prim::et_view.default" )(context_ , stack);
497
495
EXPECT_TENSOR_EQ (out, tf.make ({3 , 1 , 0 }, {}));
498
496
EXPECT_EQ (out.const_data_ptr (), self.const_data_ptr ());
499
497
500
498
// bad size test
501
499
EValue* bad_stack[3 ] = {&self_evalue, &bad_size_int_list_evalue, &out_evalue};
502
- ET_EXPECT_DEATH (
503
- getOpsFn (" executorch_prim::et_view.default" )(context, bad_stack), " " );
500
+ ET_EXPECT_KERNEL_FAILURE (
501
+ context_,
502
+ getOpsFn (" executorch_prim::et_view.default" )(context_, bad_stack));
504
503
}
505
504
506
505
TEST_F (RegisterPrimOpsTest, TestCeil) {
@@ -518,7 +517,7 @@ TEST_F(RegisterPrimOpsTest, TestCeil) {
518
517
stack[j] = &values[j];
519
518
}
520
519
521
- getOpsFn (" executorch_prim::ceil.Scalar" )(context , stack);
520
+ getOpsFn (" executorch_prim::ceil.Scalar" )(context_ , stack);
522
521
EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
523
522
}
524
523
}
@@ -539,7 +538,7 @@ TEST_F(RegisterPrimOpsTest, TestRound) {
539
538
stack[j] = &values[j];
540
539
}
541
540
542
- getOpsFn (" executorch_prim::round.Scalar" )(context , stack);
541
+ getOpsFn (" executorch_prim::round.Scalar" )(context_ , stack);
543
542
EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
544
543
}
545
544
}
@@ -559,7 +558,7 @@ TEST_F(RegisterPrimOpsTest, TestTrunc) {
559
558
stack[j] = &values[j];
560
559
}
561
560
562
- getOpsFn (" executorch_prim::trunc.Scalar" )(context , stack);
561
+ getOpsFn (" executorch_prim::trunc.Scalar" )(context_ , stack);
563
562
EXPECT_EQ (stack[1 ]->toInt (), expected[i]);
564
563
}
565
564
}
0 commit comments