Skip to content

Commit 89958d2

Browse files
dbortfacebook-github-bot
authored andcommitted
Catch invalid scalar type when parsing tensors (#1518)
Summary: Pull Request resolved: #1518 Fail non-fatally when encountering an unknown/unhandled `ScalarType` in a `.pte` file. As part of this: - Move the "types not supported yet" logic out of `scalar_type_util` and into `tensor_parser`, since that decision is an aspect of the runtime and not a fundamental aspect of `ScalarType`. - Remove the now-duplicate `sizeof_scalar_type` function, which is the same as the exsting `elementSize` function. Before this diff, `sizeof_scalar_type` did the "unsupported" checks that have now moved. - Add an `isValid()` function to let users of `ScalarType` know whether a given enum value is legit. This makes it possible to avoid the fatal error when calling `elementSize` on a bad value. - Add unit tests for the new `isValid()`. Reviewed By: larryliu0820 Differential Revision: D52451738 fbshipit-source-id: 88e47d7a3c688e3e2a68cc86114935c7ff6a73b5
1 parent d869385 commit 89958d2

File tree

7 files changed

+63
-42
lines changed

7 files changed

+63
-42
lines changed

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,16 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
253253
// Utility functions to retrieve metadata for a given ScalarType
254254
//
255255

256+
/**
257+
* Returns true if the parameter is one of the values covered by
258+
* ET_FORALL_SCALAR_TYPES.
259+
*/
260+
inline bool isValid(exec_aten::ScalarType type) {
261+
return static_cast<int8_t>(type) >= 0 &&
262+
type < exec_aten::ScalarType::NumOptions &&
263+
type != exec_aten::ScalarType::Undefined;
264+
}
265+
256266
/**
257267
* Returns the name of a ScalarType as a C string.
258268
*
@@ -541,38 +551,6 @@ inline exec_aten::ScalarType promoteTypes(
541551
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
542552
}
543553

544-
/**
545-
* Return the size of corresponding ctype given ScalarType.
546-
*/
547-
inline size_t sizeof_scalar_type(exec_aten::ScalarType type) {
548-
// Reject types that are not yet supported or are out of bounds.
549-
ET_CHECK_MSG(
550-
type != exec_aten::ScalarType::Half &&
551-
type != exec_aten::ScalarType::ComplexHalf &&
552-
type != exec_aten::ScalarType::ComplexFloat &&
553-
type != exec_aten::ScalarType::ComplexDouble &&
554-
type != exec_aten::ScalarType::BFloat16 &&
555-
type != exec_aten::ScalarType::Undefined,
556-
"Invalid or unsupported ScalarType %" PRId8,
557-
static_cast<int8_t>(type));
558-
559-
size_t type_size = 0;
560-
#define SCALAR_TYPE_SIZE(ctype, dtype) \
561-
case exec_aten::ScalarType::dtype: \
562-
type_size = sizeof(ctype); \
563-
break;
564-
565-
switch (type) {
566-
ET_FORALL_SCALAR_TYPES(SCALAR_TYPE_SIZE)
567-
default:
568-
ET_CHECK_MSG(
569-
false, "Invalid input ScalarType %" PRId8, static_cast<int8_t>(type));
570-
}
571-
#undef SCALAR_TYPE_SIZE
572-
573-
return type_size;
574-
}
575-
576554
//
577555
// Helper macros for switch case macros (see below)
578556
//

runtime/core/exec_aten/util/test/scalar_type_util_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,26 @@ TEST(ScalarTypeUtilTest, ElementSize) {
7070
}
7171
}
7272

73+
TEST(ScalarTypeUtilTest, IsValidTrue) {
74+
// Some valid types.
75+
EXPECT_TRUE(torch::executor::isValid(ScalarType::Byte));
76+
EXPECT_TRUE(torch::executor::isValid(ScalarType::Float));
77+
EXPECT_TRUE(torch::executor::isValid(ScalarType::ComplexFloat));
78+
EXPECT_TRUE(torch::executor::isValid(ScalarType::Bits16));
79+
}
80+
81+
TEST(ScalarTypeUtilTest, IsValidFalse) {
82+
// Undefined, which is sort of a special case since it's not part of the
83+
// iteration macros but is still a part of the enum.
84+
EXPECT_FALSE(torch::executor::isValid(ScalarType::Undefined));
85+
86+
// Some out-of-range types, also demonstrating that NumOptions is not really a
87+
// scalar type.
88+
EXPECT_FALSE(torch::executor::isValid(ScalarType::NumOptions));
89+
EXPECT_FALSE(torch::executor::isValid(static_cast<ScalarType>(127)));
90+
EXPECT_FALSE(torch::executor::isValid(static_cast<ScalarType>(-1)));
91+
}
92+
7393
TEST(ScalarTypeUtilTest, UnknownTypeElementSizeDies) {
7494
// Undefined, which is sort of a special case since it's not part of the
7595
// iteration macros but is still a part of the enum.

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ TensorImpl::TensorImpl(
4747
data_(data),
4848
dim_(dim),
4949
numel_(compute_numel(sizes, dim)),
50-
capacity_(numel_ * sizeof_scalar_type(type)),
50+
capacity_(numel_ * elementSize(type)),
5151
type_(type),
5252
shape_dynamism_(dynamism) {}
5353

5454
size_t TensorImpl::nbytes() const {
55-
return numel_ * sizeof_scalar_type(type_);
55+
return numel_ * elementSize(type_);
5656
}
5757

5858
ssize_t TensorImpl::size(ssize_t dim) const {
@@ -78,7 +78,7 @@ ScalarType TensorImpl::scalar_type() const {
7878

7979
// Return the size of one element of the tensor
8080
ssize_t TensorImpl::element_size() const {
81-
return sizeof_scalar_type(type_);
81+
return elementSize(type_);
8282
}
8383

8484
const ArrayRef<TensorImpl::SizesType> TensorImpl::sizes() const {
@@ -145,7 +145,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
145145

146146
// Upper bounded tensors can be reshaped but not beyond upper bound
147147
if (shape_dynamism_ == TensorShapeDynamism::DYNAMIC_BOUND) {
148-
auto new_nbytes = new_numel * sizeof_scalar_type(type_);
148+
auto new_nbytes = new_numel * elementSize(type_);
149149
ET_CHECK_OR_RETURN_ERROR(
150150
new_nbytes <= capacity_,
151151
NotSupported,

runtime/core/portable_type/test/executor_tensor_test.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ namespace executor {
1515

1616
TEST(TensorTest, InvalidScalarType) {
1717
TensorImpl::SizesType sizes[1] = {1};
18-
// A type that executorch doesn't support yet.
19-
ET_EXPECT_DEATH({ TensorImpl x(ScalarType::BFloat16, 1, sizes); }, "");
2018

21-
// The literal Undefined type.
19+
// Undefined, which is sort of a special case since it's not part of the
20+
// iteration macros but is still a part of the enum.
2221
ET_EXPECT_DEATH({ TensorImpl y(ScalarType::Undefined, 1, sizes); }, "");
2322

24-
// An int value that doesn't map to a valid enum value
23+
// Some out-of-range types, also demonstrating that NumOptions is not really a
24+
// scalar type.
2525
ET_EXPECT_DEATH({ TensorImpl y(ScalarType::NumOptions, 1, sizes); }, "");
26+
ET_EXPECT_DEATH(
27+
{ TensorImpl y(static_cast<ScalarType>(127), 1, sizes); }, "");
28+
ET_EXPECT_DEATH({ TensorImpl y(static_cast<ScalarType>(-1), 1, sizes); }, "");
2629
}
2730

2831
TEST(TensorTest, SetData) {

runtime/executor/method_meta.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ size_t calculate_nbytes(
5959
for (ssize_t i = 0; i < sizes.size(); i++) {
6060
n *= sizes[i];
6161
}
62-
return n * sizeof_scalar_type(scalar_type);
62+
return n * torch::executor::elementSize(scalar_type);
6363
}
6464

6565
} // namespace

runtime/executor/tensor_parser_aten.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/runtime/executor/tensor_parser.h>
1010

1111
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
12+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1213
#include <executorch/runtime/executor/memory_manager.h>
1314
#include <executorch/runtime/executor/program.h>
1415
#include <executorch/runtime/platform/profiler.h>
@@ -43,6 +44,11 @@ Result<at::Tensor> parseTensor(
4344

4445
// get metadata
4546
at::ScalarType type = static_cast<at::ScalarType>(s_tensor->scalar_type());
47+
ET_CHECK_OR_RETURN_ERROR(
48+
isValid(type),
49+
InvalidProgram,
50+
"Invalid ScalarType %" PRId8,
51+
static_cast<int8_t>(type));
4652
auto options = at::CPU(type).options();
4753

4854
// convert int32 in serialization to int64 for aten

runtime/executor/tensor_parser_portable.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/runtime/core/evalue.h>
1212
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1313
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
14+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1415
#include <executorch/runtime/executor/memory_manager.h>
1516
#include <executorch/runtime/executor/program.h>
1617
#include <executorch/runtime/platform/profiler.h>
@@ -33,6 +34,19 @@ Result<torch::executor::Tensor> parseTensor(
3334
"Non-zero storage offset %" PRId32 " not supported",
3435
s_tensor->storage_offset());
3536

37+
ScalarType scalar_type = static_cast<ScalarType>(s_tensor->scalar_type());
38+
ET_CHECK_OR_RETURN_ERROR(
39+
isValid(scalar_type) &&
40+
// Types that do not yet have deserialization support.
41+
scalar_type != exec_aten::ScalarType::Half &&
42+
scalar_type != exec_aten::ScalarType::ComplexHalf &&
43+
scalar_type != exec_aten::ScalarType::ComplexFloat &&
44+
scalar_type != exec_aten::ScalarType::ComplexDouble &&
45+
scalar_type != exec_aten::ScalarType::BFloat16,
46+
InvalidProgram,
47+
"Invalid or unsupported ScalarType %" PRId8,
48+
static_cast<int8_t>(scalar_type));
49+
3650
TensorShapeDynamism dynamism =
3751
static_cast<TensorShapeDynamism>(s_tensor->shape_dynamism());
3852
// TODO(T133200526): Remove this check once fully dynamic shapes are
@@ -90,7 +104,7 @@ Result<torch::executor::Tensor> parseTensor(
90104
// Placement new on the allocated memory space. Note that we create this first
91105
// with null data so we can find its expected size before getting its memory.
92106
new (tensor_impl) torch::executor::TensorImpl(
93-
static_cast<ScalarType>(s_tensor->scalar_type()),
107+
scalar_type,
94108
dim,
95109
sizes,
96110
/*data=*/nullptr,

0 commit comments

Comments
 (0)