6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < cstdint>
10
- #include < cstring>
11
-
9
+ #include < executorch/kernels/portable/cpu/util/copy_ops_util.h>
12
10
#include < executorch/runtime/kernel/kernel_includes.h>
11
+ #include < cstring>
13
12
14
13
namespace torch {
15
14
namespace executor {
@@ -19,62 +18,6 @@ using Tensor = exec_aten::Tensor;
19
18
20
19
namespace {
21
20
22
- // TODO(gasoonjia): Move this to a common spot so all implementation of
23
- // this operator can share it. (e.g., DSP-specific)
24
- // / Asserts that the parameters are valid.
25
- void check_slice_copy_Tensor_out_args (
26
- const Tensor input,
27
- int64_t dim,
28
- int64_t num_values,
29
- int64_t step,
30
- Tensor output) {
31
- //
32
- // Check dim. The dim planed to be selected on shall exist in input
33
- ET_CHECK_MSG (
34
- dim >= 0 && dim < input.dim (),
35
- " dim %" PRId64 " out of range [0,%zd)" ,
36
- dim,
37
- input.dim ());
38
-
39
- // Input dtype shall match the output dtype.
40
- ET_CHECK_SAME_DTYPE2 (input, output);
41
-
42
- // The output.dim() shall equal to input.dim(), based on the definition of
43
- // slicing.
44
- ET_CHECK_MSG (
45
- input.dim () == output.dim (),
46
- " input.dim() %zd != output.dim() %zd" ,
47
- input.dim (),
48
- output.dim ());
49
-
50
- // Check step. Step must be greater than zero
51
- ET_CHECK_MSG (step > 0 , " slice step must be greater than zero" );
52
-
53
- // The size of output tensor should follow these rules:
54
- // - output.size(i) shall equal to input.size(i) if i != dim,
55
- // - output.size(dim) shall equal to num_values
56
- for (size_t d = 0 ; d < input.dim () - 1 ; d++) {
57
- if (d != dim) {
58
- ET_CHECK_MSG (
59
- input.size (d) == output.size (d),
60
- " input.size(%zu) %zd != output.size(%zu) %zd | dim = %" PRId64 " )" ,
61
- d,
62
- input.size (d),
63
- d,
64
- output.size (d),
65
- dim);
66
- } else {
67
- ET_CHECK_MSG (
68
- output.size (d) == num_values,
69
- " input.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64 " )" ,
70
- d,
71
- input.size (d),
72
- num_values,
73
- dim);
74
- }
75
- }
76
- }
77
-
78
21
int64_t adjust_slice_indices (
79
22
int64_t dim_length,
80
23
int64_t * start,
@@ -111,46 +54,54 @@ int64_t adjust_slice_indices(
111
54
112
55
} // namespace
113
56
114
- // / slice_copy.Tensor_out(Tensor self, int dim=0, int? start=None, int?
115
- // / end=None, int step=1, *, Tensor(a!) out) -> Tensor(a!)
116
- // / -> Tensor(a!)
117
57
Tensor& slice_copy_Tensor_out (
118
58
RuntimeContext& ctx,
119
- const Tensor& input ,
59
+ const Tensor& in ,
120
60
int64_t dim,
121
61
exec_aten::optional<int64_t > start_val,
122
62
exec_aten::optional<int64_t > end_val,
123
63
int64_t step,
124
64
Tensor& out) {
125
65
(void )ctx;
66
+
67
+ ET_KERNEL_CHECK (
68
+ ctx, check_slice_copy_args (in, dim, step, out), InvalidArgument, out);
69
+
126
70
if (dim < 0 ) {
127
- dim += input .dim ();
71
+ dim += in .dim ();
128
72
}
129
73
130
- // If user do not set value to end_val, set end to input .size(dim) (largest
74
+ // If user do not set value to end_val, set end to in .size(dim) (largest
131
75
// value available)
132
- int64_t end = end_val.has_value () ? end_val.value () : input .size (dim);
76
+ int64_t end = end_val.has_value () ? end_val.value () : in .size (dim);
133
77
// If user do not set value to start_val, set start to 0 (smallest value
134
78
// available)
135
79
int64_t start = start_val.has_value () ? start_val.value () : 0 ;
136
80
137
- int64_t num_values =
138
- adjust_slice_indices (input.size (dim), &start, &end, step);
81
+ int64_t num_values = adjust_slice_indices (in.size (dim), &start, &end, step);
139
82
140
- check_slice_copy_Tensor_out_args (input, dim, num_values, step, out);
83
+ Tensor::SizesType target_sizes[kTensorDimensionLimit ];
84
+ size_t target_ndim = 0 ;
85
+ get_slice_copy_out_target_size (
86
+ in, dim, num_values, target_sizes, &target_ndim);
87
+ ET_KERNEL_CHECK (
88
+ ctx,
89
+ resize_tensor (out, {target_sizes, target_ndim}) == Error::Ok,
90
+ InvalidArgument,
91
+ out);
141
92
142
- size_t dim_length = input .size (dim);
93
+ size_t dim_length = in .size (dim);
143
94
144
- size_t leading_dims = getLeadingDims (input , dim);
145
- size_t trailing_dims = getTrailingDims (input , dim);
95
+ size_t leading_dims = getLeadingDims (in , dim);
96
+ size_t trailing_dims = getTrailingDims (in , dim);
146
97
147
98
if (trailing_dims == 0 ) {
148
99
return out;
149
100
}
150
101
151
- size_t length_per_step = trailing_dims * input .element_size ();
102
+ size_t length_per_step = trailing_dims * in .element_size ();
152
103
153
- const char * input_data = input .const_data_ptr <char >();
104
+ const char * input_data = in .const_data_ptr <char >();
154
105
char * dest = out.mutable_data_ptr <char >();
155
106
156
107
for (int i = 0 ; i < leading_dims; i++) {
0 commit comments