Skip to content

Commit a5c1bb9

Browse files
authored
Small refactoring of TensorImpl. (#4640)
Summary: Do the actual size change after all the checks. Reviewed By: JacobSzwejbka Differential Revision: D60854861
1 parent 79c15ef commit a5c1bb9

File tree

2 files changed

+35
-51
lines changed

2 files changed

+35
-51
lines changed

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ jobs:
337337
size=${arr[4]}
338338
# threshold=48120 on devserver with gcc11.4
339339
# todo(lfq): update once binary size is below 50kb.
340-
threshold="51768"
340+
threshold="51784"
341341
if [[ "$size" -le "$threshold" ]]; then
342342
echo "Success $size <= $threshold"
343343
else

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 34 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#include <executorch/runtime/core/portable_type/tensor_impl.h>
1010

11+
#include <algorithm>
1112
#include <cstdint>
12-
#include <cstring> // std::memcpy
1313

1414
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
1515
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -25,11 +25,11 @@ namespace {
2525
* Compute the number of elements based on the sizes of a tensor.
2626
*/
2727
ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
28-
ssize_t n = 1;
29-
for (ssize_t i = 0; i < dim; i++) {
30-
n *= sizes[i];
28+
ssize_t numel = 1; // Zero-dimensional tensors (scalars) have numel == 1.
29+
for (ssize_t i = 0; i < dim; ++i) {
30+
numel *= sizes[i];
3131
}
32-
return n;
32+
return numel;
3333
}
3434
} // namespace
3535

@@ -67,7 +67,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
6767
ET_CHECK_OR_RETURN_ERROR(
6868
new_sizes.size() == dim_,
6969
NotSupported,
70-
"ETensor rank is immutable old: %zu new: %zu",
70+
"Attempted to change the tensor rank which is immutable: old=%zu, new=%zu",
7171
dim_,
7272
new_sizes.size());
7373

@@ -82,55 +82,39 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
8282
if (dim_ == 0) {
8383
return Error::Ok;
8484
}
85-
86-
// Can only resize a StaticShape Tensor to the same size
87-
if (shape_dynamism_ == TensorShapeDynamism::STATIC) {
88-
for (int i = 0; i < new_sizes.size(); i++) {
85+
switch (shape_dynamism_) {
86+
case TensorShapeDynamism::STATIC:
8987
ET_CHECK_OR_RETURN_ERROR(
90-
new_sizes[i] == sizes_[i],
88+
std::equal(sizes_, sizes_ + dim_, new_sizes.begin()),
9189
NotSupported,
92-
"Attempted to resize a static tensor to a new shape at "
93-
"dimension %d old_size: %d new_size: %d",
94-
i,
95-
sizes_[i],
96-
new_sizes[i]);
97-
}
98-
// no work to do after checking for error
99-
return Error::Ok;
100-
}
101-
102-
const auto new_numel = compute_numel(new_sizes.data(), dim_);
103-
104-
// Bounded tensors can be reshaped, but not beyond the upper bound.
105-
if (shape_dynamism_ == TensorShapeDynamism::DYNAMIC_BOUND ||
90+
"Attempted to resize a static tensor");
91+
break;
92+
case TensorShapeDynamism::DYNAMIC_BOUND:
10693
// TODO(T175194371): Unbounded dynamic tensor resizing is not yet
10794
// supported: treat them as upper-bounded.
108-
shape_dynamism_ == TensorShapeDynamism::DYNAMIC_UNBOUND) {
109-
ET_CHECK_OR_RETURN_ERROR(
110-
new_numel <= numel_bound_,
111-
NotSupported,
112-
"Attempted to resize a bounded tensor with capacity of %zu elements to %zu elements.",
113-
new_numel,
114-
numel_bound_);
95+
case TensorShapeDynamism::DYNAMIC_UNBOUND: {
96+
const auto new_numel = compute_numel(new_sizes.data(), dim_);
97+
ET_CHECK_OR_RETURN_ERROR(
98+
new_numel <= numel_bound_,
99+
NotSupported,
100+
"Attempted to resize a bounded tensor with capacity of %zu elements to %zu elements.",
101+
new_numel,
102+
numel_bound_);
103+
ET_CHECK_OR_RETURN_ERROR(
104+
strides_ != nullptr,
105+
Internal,
106+
"Strides cannot be nullptr for resize");
107+
ET_CHECK_OR_RETURN_ERROR(
108+
dim_order_ != nullptr,
109+
Internal,
110+
"Dim order cannot be nullptr for resize");
111+
ET_CHECK_OK_OR_RETURN_ERROR(
112+
dim_order_to_stride(new_sizes.data(), dim_order_, dim_, strides_));
113+
114+
numel_ = new_numel;
115+
std::copy(new_sizes.begin(), new_sizes.end(), sizes_);
116+
}
115117
}
116-
117-
// Copy sizes over
118-
std::memcpy(sizes_, new_sizes.data(), sizeof(SizesType) * dim_);
119-
120-
// Compute new strides
121-
ET_CHECK_OR_RETURN_ERROR(
122-
strides_ != nullptr, Internal, "Strides cannot be nullptr for resize");
123-
ET_CHECK_OR_RETURN_ERROR(
124-
dim_order_ != nullptr,
125-
Internal,
126-
"Dim order cannot be nullptr for resize");
127-
auto status = dim_order_to_stride(sizes_, dim_order_, dim_, strides_);
128-
ET_CHECK_OR_RETURN_ERROR(
129-
status == Error::Ok,
130-
Internal,
131-
"dim_order_to_stride returned invalid status");
132-
numel_ = new_numel;
133-
134118
return Error::Ok;
135119
}
136120

0 commit comments

Comments
 (0)