Skip to content

Commit 801e1c9

Browse files
authored
Forbid having TensorImpl with zero number of elements and null data.
Differential Revision: D61810277 Pull Request resolved: #4909
1 parent 5942e4a commit 801e1c9

File tree

3 files changed

+91
-4
lines changed

3 files changed

+91
-4
lines changed

runtime/core/exec_aten/testing_util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def define_common_targets():
2323
# list.
2424
"//executorch/runtime/core/exec_aten/util/test/...",
2525
"//executorch/runtime/core/exec_aten/testing_util/test/...",
26+
"//executorch/runtime/core/portable_type/test/...",
2627
"//executorch/kernels/prim_ops/test/...",
2728
"//executorch/kernels/portable/test/...",
2829
"//executorch/kernels/portable/cpu/util/test/...",

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,16 @@ namespace {
2525
* Compute the number of elements based on the sizes of a tensor.
2626
*/
2727
ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
28+
ET_CHECK_MSG(
29+
dim == 0 || sizes != nullptr,
30+
"Sizes must be provided for non-scalar tensors");
2831
ssize_t numel = 1; // Zero-dimensional tensors (scalars) have numel == 1.
2932
for (ssize_t i = 0; i < dim; ++i) {
33+
ET_CHECK_MSG(
34+
sizes[i] >= 0,
35+
"Size must be non-negative, got %d at dimension %zd",
36+
sizes[i],
37+
i);
3038
numel *= sizes[i];
3139
}
3240
return numel;
@@ -52,6 +60,7 @@ TensorImpl::TensorImpl(
5260
shape_dynamism_(dynamism) {
5361
ET_CHECK_MSG(
5462
isValid(type_), "Invalid type %" PRId8, static_cast<int8_t>(type_));
63+
ET_CHECK_MSG(dim_ >= 0, "Dimension must be non-negative, got %zd", dim_);
5564
}
5665

5766
size_t TensorImpl::nbytes() const {

runtime/core/portable_type/test/tensor_impl_test.cpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
#include <executorch/runtime/core/portable_type/tensor_impl.h>
1010

11-
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
12-
#include <executorch/runtime/platform/runtime.h>
13-
1411
#include <gtest/gtest.h>
1512
#include <random>
1613

14+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
15+
#include <executorch/runtime/platform/runtime.h>
16+
#include <executorch/test/utils/DeathTest.h>
17+
1718
using namespace ::testing;
1819

1920
namespace torch {
@@ -29,7 +30,7 @@ class TensorImplTest : public ::testing::Test {
2930
void SetUp() override {
3031
// Since these tests cause ET_LOG to be called, the PAL must be initialized
3132
// first.
32-
torch::executor::runtime_init();
33+
runtime_init();
3334
}
3435
};
3536

@@ -370,5 +371,81 @@ TEST_F(TensorImplTest, TestWriteRead) {
370371
EXPECT_EQ(y[0], 22.0);
371372
}
372373

374+
TEST_F(TensorImplTest, TestInvalidScalarType) {
375+
SizesType sizes[2] = {3, 2};
376+
ET_EXPECT_DEATH(TensorImpl t(static_cast<ScalarType>(-1), 2, sizes), "");
377+
}
378+
379+
TEST_F(TensorImplTest, TestNegativeDimension) {
380+
SizesType sizes[2] = {3, 2};
381+
ET_EXPECT_DEATH(TensorImpl t(ScalarType::Float, -1, sizes), "");
382+
}
383+
384+
TEST_F(TensorImplTest, TestNullSizesNonZeroDim) {
385+
ET_EXPECT_DEATH(TensorImpl t(ScalarType::Float, 2, nullptr), "");
386+
}
387+
388+
TEST_F(TensorImplTest, TestNonNegativeSizes) {
389+
SizesType sizes[2] = {3, -2};
390+
ET_EXPECT_DEATH(TensorImpl t(ScalarType::Float, 2, sizes), "");
391+
}
392+
393+
TEST_F(TensorImplTest, TestEmptyTensor) {
394+
SizesType sizes[2] = {0, 0};
395+
TensorImpl t(ScalarType::Float, 2, sizes);
396+
EXPECT_EQ(t.numel(), 0);
397+
EXPECT_EQ(t.data(), nullptr);
398+
}
399+
400+
TEST_F(TensorImplTest, TestTensorWithNoElementsButAllocatedMemory) {
401+
SizesType sizes[2] = {0, 0};
402+
float data[1] = {1.0};
403+
TensorImpl t(ScalarType::Float, 2, sizes, data);
404+
EXPECT_EQ(t.numel(), 0);
405+
EXPECT_EQ(t.data(), data);
406+
}
407+
408+
TEST_F(TensorImplTest, TestTensorWithShapeButNoMemory) {
409+
SizesType sizes[2] = {3, 2};
410+
TensorImpl t(ScalarType::Float, 2, sizes);
411+
EXPECT_GT(t.numel(), 0);
412+
EXPECT_EQ(t.data(), nullptr);
413+
}
414+
415+
TEST_F(TensorImplTest, TestNormalTensor) {
416+
SizesType sizes[2] = {3, 2};
417+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
418+
TensorImpl t(ScalarType::Float, 2, sizes, data);
419+
EXPECT_GT(t.numel(), 0);
420+
EXPECT_EQ(t.data(), data);
421+
}
422+
423+
TEST_F(TensorImplTest, TestResizingTensorToZeroAndBack) {
424+
SizesType sizes[2] = {3, 2};
425+
TensorImpl t(
426+
ScalarType::Float,
427+
2,
428+
sizes,
429+
nullptr,
430+
nullptr,
431+
nullptr,
432+
TensorShapeDynamism::DYNAMIC_BOUND);
433+
434+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
435+
t.set_data(data);
436+
EXPECT_GT(t.numel(), 0);
437+
EXPECT_EQ(t.data(), data);
438+
439+
SizesType zero_sizes[2] = {0, 0};
440+
t.set_sizes_contiguous({zero_sizes, 2});
441+
EXPECT_EQ(t.numel(), 0);
442+
EXPECT_EQ(t.data(), data);
443+
444+
SizesType new_sizes[2] = {3, 2};
445+
t.set_sizes_contiguous({new_sizes, 2});
446+
EXPECT_GT(t.numel(), 0);
447+
EXPECT_EQ(t.data(), data);
448+
}
449+
373450
} // namespace executor
374451
} // namespace torch

0 commit comments

Comments
 (0)