@@ -6,8 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <cstring>
 
 namespace torch {
 namespace executor {
@@ -17,40 +17,6 @@ using Tensor = exec_aten::Tensor;
 using ScalarType = exec_aten::ScalarType;
 
 namespace {
-size_t compute_storage_nbytes(
-    IntArrayRef sizes,
-    IntArrayRef strides,
-    size_t itemsize_bytes) {
-  // size of the underlying storage is 1 bigger than the offset
-  // of the last element according to stride
-  size_t size = 1;
-  for (size_t i = 0; i < sizes.size(); ++i) {
-    if (sizes[i] == 0) {
-      return 0;
-    }
-    size += strides[i] * (sizes[i] - 1);
-  }
-  return size * itemsize_bytes;
-}
-
-void check_inbounds_for_storage(
-    const Tensor& self,
-    ArrayRef<int64_t> size,
-    ArrayRef<int64_t> stride,
-    int64_t storage_offset) {
-  size_t storage_size_bytes =
-      compute_storage_nbytes(size, stride, self.element_size());
-  size_t storage_offset_bytes = storage_offset * self.element_size();
-  if (storage_size_bytes == 0) {
-    return;
-  }
-  size_t new_storage_size_bytes = self.nbytes();
-  ET_CHECK_MSG(
-      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
-      "Requiring a storage size of %zd are out of bounds for storage of size %zd",
-      storage_size_bytes + storage_offset_bytes,
-      new_storage_size_bytes);
-}
 
 /**
  * Copy input_data to output_data according to the stride and shape recursively
@@ -81,39 +47,8 @@ void as_strided_copy(
   }
 }
 
-void check_preconditions(
-    const Tensor& self,
-    ArrayRef<int64_t> size,
-    ArrayRef<int64_t> stride,
-    optional<int64_t> storage_offset,
-    Tensor& out) {
-  ET_CHECK_SAME_DTYPE2(self, out);
-  ET_CHECK_MSG(
-      size.size() == stride.size(), "mismatch in length of strides and shape");
-  for (const auto& val : stride) {
-    ET_CHECK_MSG(
-        val >= 0,
-        "as_strided: Negative strides are not supported at the moment");
-  }
-  ET_CHECK_MSG(
-      out.sizes().size() == size.size(),
-      "output tensor should have same shape as size");
-  for (size_t i = 0; i < out.sizes().size(); ++i) {
-    ET_CHECK_MSG(
-        out.sizes().at(i) == size.at(i),
-        "output tensor should have same shape as size");
-  }
-  int64_t offset = storage_offset.has_value() ? storage_offset.value() : 0;
-  ET_CHECK_MSG(offset >= 0, "Negative storage offset");
-  check_inbounds_for_storage(self, size, stride, offset);
-}
-
 } // namespace
 
-/**
- * Copy the tener `self` to `out`, assume `self` and `out` have same type and
- * shape
- */
 Tensor& as_strided_copy_out(
     RuntimeContext& ctx,
     const Tensor& self,
@@ -123,34 +58,35 @@ Tensor& as_strided_copy_out(
     Tensor& out) {
   (void)ctx;
 
-  torch::executor::Error err = resize_tensor(out, size);
-  ET_CHECK_MSG(
-      err == torch::executor::Error::Ok,
-      "Failed to resize out Tensor in as_strided_copy_out");
+  ET_KERNEL_CHECK(
+      ctx,
+      check_as_strided_copy_args(self, size, stride, storage_offset, out),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, size) == torch::executor::Error::Ok,
+      InvalidArgument,
+      out);
+
+  if (self.numel() == 0) {
+    return out;
+  }
 
-  check_preconditions(self, size, stride, storage_offset, out);
   size_t offset = storage_offset.has_value() ? storage_offset.value() : 0;
 
-#define AS_STRIDED_COPY_TENSOR(ctype, dtype)                        \
-  case ScalarType::dtype:                                           \
-    as_strided_copy<ctype>(                                         \
-        /* input_data= */ self.mutable_data_ptr<ctype>() + offset,  \
-        /* output_data= */ out.mutable_data_ptr<ctype>(),           \
-        out,                                                        \
-        size,                                                       \
-        stride,                                                     \
-        /* dim= */ 0);                                              \
-    break;
+  ET_SWITCH_ALL_TYPES(self.scalar_type(), ctx, __func__, CTYPE, [&] {
+    CTYPE* self_data = self.mutable_data_ptr<CTYPE>() + offset;
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+
+    if (size.empty()) {
+      out_data[0] = self_data[0];
+    } else {
+      as_strided_copy<CTYPE>(self_data, out_data, out, size, stride, 0);
+    }
+  });
 
-  switch (self.scalar_type()) {
-    ET_FORALL_SCALAR_TYPES(AS_STRIDED_COPY_TENSOR)
-    default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled dtype %" PRId8,
-          static_cast<int8_t>(self.scalar_type()));
-  }
-#undef AS_STRIDED_COPY_TENSOR
   return out;
 }
 
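For reference, the precondition logic deleted above reduces to one piece of arithmetic: the smallest storage a strided view can legally address is 1 + sum(stride[i] * (size[i] - 1)) elements (0 if any dimension is empty), scaled by the element size and shifted by the storage offset, and that total must fit in self.nbytes(). The sketch below restates that calculation as a standalone, dependency-free program. It is an illustration only: it assumes the shared check_as_strided_copy_args() in copy_ops_util.h performs an equivalent bounds check, and strided_storage_nbytes is a hypothetical name, not an ExecuTorch API.

// Standalone sketch, not part of the diff: the storage-bound arithmetic that
// the removed compute_storage_nbytes() implemented. Plain C++ with
// std::vector standing in for IntArrayRef; strided_storage_nbytes is an
// illustrative name only.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Smallest storage (in bytes) addressable by a view with the given sizes and
// strides: 1 + sum(strides[i] * (sizes[i] - 1)) elements, or 0 if any
// dimension has size 0.
size_t strided_storage_nbytes(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    size_t itemsize_bytes) {
  size_t elems = 1;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) {
      return 0;
    }
    elems += strides[i] * (sizes[i] - 1);
  }
  return elems * itemsize_bytes;
}

int main() {
  // A 2x3 float view with strides {3, 1} reaches element offsets 0..5, so it
  // needs (1 + 3 * 1 + 1 * 2) * 4 = 24 bytes of underlying storage; the
  // storage offset (in elements) is added on top before comparing against the
  // source tensor's total byte size.
  std::printf("%zu\n", strided_storage_nbytes({2, 3}, {3, 1}, sizeof(float)));
  return 0;
}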