8
8
9
9
#include < executorch/runtime/core/portable_type/tensor_impl.h>
10
10
11
+ #include < algorithm>
11
12
#include < cstdint>
12
- #include < cstring> // std::memcpy
13
13
14
14
#include < executorch/runtime/core/exec_aten/util/dim_order_util.h>
15
15
#include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -25,11 +25,11 @@ namespace {
25
25
* Compute the number of elements based on the sizes of a tensor.
26
26
*/
27
27
ssize_t compute_numel (const TensorImpl::SizesType* sizes, ssize_t dim) {
28
- ssize_t n = 1 ;
29
- for (ssize_t i = 0 ; i < dim; i++ ) {
30
- n *= sizes[i];
28
+ ssize_t numel = 1 ; // Zero-dimensional tensors (scalars) have numel == 1.
29
+ for (ssize_t i = 0 ; i < dim; ++i ) {
30
+ numel *= sizes[i];
31
31
}
32
- return n ;
32
+ return numel ;
33
33
}
34
34
} // namespace
35
35
@@ -67,7 +67,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
67
67
ET_CHECK_OR_RETURN_ERROR (
68
68
new_sizes.size () == dim_,
69
69
NotSupported,
70
- " ETensor rank is immutable old: %zu new: %zu" ,
70
+ " Attempted to change the tensor rank which is immutable: old= %zu, new= %zu" ,
71
71
dim_,
72
72
new_sizes.size ());
73
73
@@ -82,55 +82,39 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
82
82
if (dim_ == 0 ) {
83
83
return Error::Ok;
84
84
}
85
-
86
- // Can only resize a StaticShape Tensor to the same size
87
- if (shape_dynamism_ == TensorShapeDynamism::STATIC) {
88
- for (int i = 0 ; i < new_sizes.size (); i++) {
85
+ switch (shape_dynamism_) {
86
+ case TensorShapeDynamism::STATIC:
89
87
ET_CHECK_OR_RETURN_ERROR (
90
- new_sizes[i] == sizes_[i] ,
88
+ std::equal (sizes_, sizes_ + dim_, new_sizes. begin ()) ,
91
89
NotSupported,
92
- " Attempted to resize a static tensor to a new shape at "
93
- " dimension %d old_size: %d new_size: %d" ,
94
- i,
95
- sizes_[i],
96
- new_sizes[i]);
97
- }
98
- // no work to do after checking for error
99
- return Error::Ok;
100
- }
101
-
102
- const auto new_numel = compute_numel (new_sizes.data (), dim_);
103
-
104
- // Bounded tensors can be reshaped, but not beyond the upper bound.
105
- if (shape_dynamism_ == TensorShapeDynamism::DYNAMIC_BOUND ||
90
+ " Attempted to resize a static tensor" );
91
+ break ;
92
+ case TensorShapeDynamism::DYNAMIC_BOUND:
106
93
// TODO(T175194371): Unbounded dynamic tensor resizing is not yet
107
94
// supported: treat them as upper-bounded.
108
- shape_dynamism_ == TensorShapeDynamism::DYNAMIC_UNBOUND) {
109
- ET_CHECK_OR_RETURN_ERROR (
110
- new_numel <= numel_bound_,
111
- NotSupported,
112
- " Attempted to resize a bounded tensor with capacity of %zu elements to %zu elements." ,
113
- new_numel,
114
- numel_bound_);
95
+ case TensorShapeDynamism::DYNAMIC_UNBOUND: {
96
+ const auto new_numel = compute_numel (new_sizes.data (), dim_);
97
+ ET_CHECK_OR_RETURN_ERROR (
98
+ new_numel <= numel_bound_,
99
+ NotSupported,
100
+ " Attempted to resize a bounded tensor with capacity of %zu elements to %zu elements." ,
101
+ new_numel,
102
+ numel_bound_);
103
+ ET_CHECK_OR_RETURN_ERROR (
104
+ strides_ != nullptr ,
105
+ Internal,
106
+ " Strides cannot be nullptr for resize" );
107
+ ET_CHECK_OR_RETURN_ERROR (
108
+ dim_order_ != nullptr ,
109
+ Internal,
110
+ " Dim order cannot be nullptr for resize" );
111
+ ET_CHECK_OK_OR_RETURN_ERROR (
112
+ dim_order_to_stride (new_sizes.data (), dim_order_, dim_, strides_));
113
+
114
+ numel_ = new_numel;
115
+ std::copy (new_sizes.begin (), new_sizes.end (), sizes_);
116
+ }
115
117
}
116
-
117
- // Copy sizes over
118
- std::memcpy (sizes_, new_sizes.data (), sizeof (SizesType) * dim_);
119
-
120
- // Compute new strides
121
- ET_CHECK_OR_RETURN_ERROR (
122
- strides_ != nullptr , Internal, " Strides cannot be nullptr for resize" );
123
- ET_CHECK_OR_RETURN_ERROR (
124
- dim_order_ != nullptr ,
125
- Internal,
126
- " Dim order cannot be nullptr for resize" );
127
- auto status = dim_order_to_stride (sizes_, dim_order_, dim_, strides_);
128
- ET_CHECK_OR_RETURN_ERROR (
129
- status == Error::Ok,
130
- Internal,
131
- " dim_order_to_stride returned invalid status" );
132
- numel_ = new_numel;
133
-
134
118
return Error::Ok;
135
119
}
136
120
0 commit comments