Skip to content

Commit 6a6733d

Browse files
manuelcandales authored and facebook-github-bot committed
Fix & cleanup op slice_copy (#703)
Summary: Pull Request resolved: #703. Resize the out tensor. ghstack-source-id: 203341580. Exported using ghexport. Reviewed By: SS-JIA. Differential Revision: D49794150. fbshipit-source-id: 7e351595aa71f35336322db342388728d070afa4
1 parent 7f36c70 commit 6a6733d

File tree

4 files changed

+67
-74
lines changed

4 files changed

+67
-74
lines changed

kernels/portable/cpu/op_slice_copy.cpp

Lines changed: 25 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cstdint>
10-
#include <cstring>
11-
9+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
1210
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <cstring>
1312

1413
namespace torch {
1514
namespace executor {
@@ -19,62 +18,6 @@ using Tensor = exec_aten::Tensor;
1918

2019
namespace {
2120

22-
// TODO(gasoonjia): Move this to a common spot so all implementation of
23-
// this operator can share it. (e.g., DSP-specific)
24-
/// Asserts that the parameters are valid.
25-
void check_slice_copy_Tensor_out_args(
26-
const Tensor input,
27-
int64_t dim,
28-
int64_t num_values,
29-
int64_t step,
30-
Tensor output) {
31-
//
32-
// Check dim. The dim planed to be selected on shall exist in input
33-
ET_CHECK_MSG(
34-
dim >= 0 && dim < input.dim(),
35-
"dim %" PRId64 " out of range [0,%zd)",
36-
dim,
37-
input.dim());
38-
39-
// Input dtype shall match the output dtype.
40-
ET_CHECK_SAME_DTYPE2(input, output);
41-
42-
// The output.dim() shall equal to input.dim(), based on the definition of
43-
// slicing.
44-
ET_CHECK_MSG(
45-
input.dim() == output.dim(),
46-
"input.dim() %zd != output.dim() %zd",
47-
input.dim(),
48-
output.dim());
49-
50-
// Check step. Step must be greater than zero
51-
ET_CHECK_MSG(step > 0, "slice step must be greater than zero");
52-
53-
// The size of output tensor should follow these rules:
54-
// - output.size(i) shall equal to input.size(i) if i != dim,
55-
// - output.size(dim) shall equal to num_values
56-
for (size_t d = 0; d < input.dim() - 1; d++) {
57-
if (d != dim) {
58-
ET_CHECK_MSG(
59-
input.size(d) == output.size(d),
60-
"input.size(%zu) %zd != output.size(%zu) %zd | dim = %" PRId64 ")",
61-
d,
62-
input.size(d),
63-
d,
64-
output.size(d),
65-
dim);
66-
} else {
67-
ET_CHECK_MSG(
68-
output.size(d) == num_values,
69-
"input.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64 ")",
70-
d,
71-
input.size(d),
72-
num_values,
73-
dim);
74-
}
75-
}
76-
}
77-
7821
int64_t adjust_slice_indices(
7922
int64_t dim_length,
8023
int64_t* start,
@@ -111,46 +54,54 @@ int64_t adjust_slice_indices(
11154

11255
} // namespace
11356

114-
/// slice_copy.Tensor_out(Tensor self, int dim=0, int? start=None, int?
115-
/// end=None, int step=1, *, Tensor(a!) out) -> Tensor(a!)
116-
/// -> Tensor(a!)
11757
Tensor& slice_copy_Tensor_out(
11858
RuntimeContext& ctx,
119-
const Tensor& input,
59+
const Tensor& in,
12060
int64_t dim,
12161
exec_aten::optional<int64_t> start_val,
12262
exec_aten::optional<int64_t> end_val,
12363
int64_t step,
12464
Tensor& out) {
12565
(void)ctx;
66+
67+
ET_KERNEL_CHECK(
68+
ctx, check_slice_copy_args(in, dim, step, out), InvalidArgument, out);
69+
12670
if (dim < 0) {
127-
dim += input.dim();
71+
dim += in.dim();
12872
}
12973

130-
// If user do not set value to end_val, set end to input.size(dim) (largest
74+
// If user do not set value to end_val, set end to in.size(dim) (largest
13175
// value available)
132-
int64_t end = end_val.has_value() ? end_val.value() : input.size(dim);
76+
int64_t end = end_val.has_value() ? end_val.value() : in.size(dim);
13377
// If user do not set value to start_val, set start to 0 (smallest value
13478
// available)
13579
int64_t start = start_val.has_value() ? start_val.value() : 0;
13680

137-
int64_t num_values =
138-
adjust_slice_indices(input.size(dim), &start, &end, step);
81+
int64_t num_values = adjust_slice_indices(in.size(dim), &start, &end, step);
13982

140-
check_slice_copy_Tensor_out_args(input, dim, num_values, step, out);
83+
Tensor::SizesType target_sizes[kTensorDimensionLimit];
84+
size_t target_ndim = 0;
85+
get_slice_copy_out_target_size(
86+
in, dim, num_values, target_sizes, &target_ndim);
87+
ET_KERNEL_CHECK(
88+
ctx,
89+
resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
90+
InvalidArgument,
91+
out);
14192

142-
size_t dim_length = input.size(dim);
93+
size_t dim_length = in.size(dim);
14394

144-
size_t leading_dims = getLeadingDims(input, dim);
145-
size_t trailing_dims = getTrailingDims(input, dim);
95+
size_t leading_dims = getLeadingDims(in, dim);
96+
size_t trailing_dims = getTrailingDims(in, dim);
14697

14798
if (trailing_dims == 0) {
14899
return out;
149100
}
150101

151-
size_t length_per_step = trailing_dims * input.element_size();
102+
size_t length_per_step = trailing_dims * in.element_size();
152103

153-
const char* input_data = input.const_data_ptr<char>();
104+
const char* input_data = in.const_data_ptr<char>();
154105
char* dest = out.mutable_data_ptr<char>();
155106

156107
for (int i = 0; i < leading_dims; i++) {

kernels/portable/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ _ATEN_OPS = (
676676
),
677677
op_target(
678678
name = "op_slice_copy",
679+
deps = [
680+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
681+
],
679682
),
680683
op_target(
681684
name = "op_slice_scatter",

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,32 @@ void get_pixel_shuffle_out_target_size(
217217
out_sizes[i] = in.size(i) * casted_upscale_factor;
218218
}
219219

220+
bool check_slice_copy_args(
221+
const Tensor& in,
222+
int64_t dim,
223+
int64_t step,
224+
Tensor& out) {
225+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
226+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
227+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
228+
step > 0, "slice step must be greater than zero");
229+
return true;
230+
}
231+
232+
void get_slice_copy_out_target_size(
233+
const Tensor& in,
234+
int64_t dim,
235+
int64_t num_values,
236+
Tensor::SizesType* out_sizes,
237+
size_t* out_ndim) {
238+
*out_ndim = in.dim();
239+
240+
for (size_t d = 0; d < in.dim(); ++d) {
241+
out_sizes[d] = in.size(d);
242+
}
243+
out_sizes[dim] = num_values;
244+
}
245+
220246
bool check_split_with_sizes_copy_args(
221247
const Tensor& in,
222248
exec_aten::ArrayRef<int64_t> split_sizes,

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ void get_pixel_shuffle_out_target_size(
5050
Tensor::SizesType* out_sizes,
5151
size_t* out_ndim);
5252

53+
bool check_slice_copy_args(
54+
const Tensor& in,
55+
int64_t dim,
56+
int64_t step,
57+
Tensor& out);
58+
59+
void get_slice_copy_out_target_size(
60+
const Tensor& in,
61+
int64_t dim,
62+
int64_t num_values,
63+
Tensor::SizesType* out_sizes,
64+
size_t* out_ndim);
65+
5366
bool check_split_with_sizes_copy_args(
5467
const Tensor& in,
5568
exec_aten::ArrayRef<int64_t> split_sizes,

0 commit comments

Comments
 (0)