Skip to content

Commit 3f535e0

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Add split_with_sizes_copy (#537)
Summary: Pull Request resolved: #537 Adds an implementation of `aten.split_with_sizes_copy`. The implementation is identical to the existing implementation for `split.Tensor`. The only difference is the input checking and output resizing functions. Reviewed By: manuelcandales Differential Revision: D49763676 fbshipit-source-id: a1b58d6d78e3fc80ab988d21c80756bf3263b6b4
1 parent e7228d4 commit 3f535e0

File tree

8 files changed

+260
-0
lines changed

8 files changed

+260
-0
lines changed

exir/dialects/edge/op/sample_input.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,23 @@
11501150
),
11511151
],
11521152
},
1153+
"split_with_sizes_copy.default": { # (Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
1154+
"args": [
1155+
InArg(ArgType.Tensor, size=[2, 6, 3]),
1156+
InArg(ArgType.LengthList, value=[3, 1, 2]),
1157+
InArg(ArgType.Dim, value=1),
1158+
],
1159+
"returns": [
1160+
Return(
1161+
ArgType.TensorList,
1162+
value=[
1163+
Return(ArgType.Tensor, size=[2, 3, 3]),
1164+
Return(ArgType.Tensor, size=[2, 1, 3]),
1165+
Return(ArgType.Tensor, size=[2, 2, 3]),
1166+
],
1167+
),
1168+
],
1169+
},
11531170
"sqrt.default": { # (Tensor self) -> Tensor
11541171
"args": [
11551172
InArg(ArgType.Tensor),
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cstdint>
10+
#include <cstring>
11+
12+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using Tensor = exec_aten::Tensor;
20+
using TensorList = exec_aten::TensorList;
21+
22+
void split_with_sizes_copy_out(
23+
RuntimeContext& ctx,
24+
const Tensor& in,
25+
exec_aten::ArrayRef<int64_t> split_sizes,
26+
int64_t dim,
27+
TensorList out) {
28+
(void)ctx;
29+
// Support python-style negative indexing. Note that this op does not accept 0
30+
// dimensional input tensors.
31+
if (dim < 0) {
32+
dim += in.dim();
33+
}
34+
35+
ET_KERNEL_CHECK(
36+
ctx,
37+
check_split_with_sizes_copy_args(in, split_sizes, dim, out),
38+
InvalidArgument,
39+
out);
40+
41+
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
42+
size_t expected_out_dim = 0;
43+
for (size_t i = 0; i < split_sizes.size(); i++) {
44+
expected_out_size[expected_out_dim++] = split_sizes[i];
45+
get_split_with_sizes_copy_out_target_size(
46+
in, split_sizes[i], dim, expected_out_size, &expected_out_dim);
47+
ET_KERNEL_CHECK(
48+
ctx,
49+
resize_tensor(out[i], {expected_out_size, expected_out_dim}) ==
50+
Error::Ok,
51+
InvalidArgument,
52+
out);
53+
}
54+
55+
const size_t leading_dims = getLeadingDims(in, dim);
56+
const size_t trailing_dims = getTrailingDims(in, dim);
57+
const size_t step = in.size(dim) * trailing_dims;
58+
59+
ScalarType in_type = in.scalar_type();
60+
ScalarType out_type = out[0].scalar_type();
61+
62+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
63+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
64+
const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
65+
for (size_t i = 0, e = out.size(); i < e; ++i) {
66+
size_t out_step = out[i].size(dim) * trailing_dims;
67+
if (out_step == 0) {
68+
continue;
69+
}
70+
const CTYPE_IN* src = in_data;
71+
CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
72+
for (size_t j = 0; j < leading_dims; ++j) {
73+
for (size_t k = 0; k < out_step; ++k) {
74+
dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
75+
}
76+
src += step;
77+
dest += out_step;
78+
}
79+
in_data += out_step;
80+
}
81+
});
82+
});
83+
}
84+
85+
} // namespace native
86+
} // namespace executor
87+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,12 @@ _ATEN_OPS = (
691691
op_target(
692692
name = "op_split_copy",
693693
),
694+
op_target(
695+
name = "op_split_with_sizes_copy",
696+
deps = [
697+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
698+
],
699+
),
694700
op_target(
695701
name = "op_sqrt",
696702
deps = [

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,47 @@ void get_pixel_shuffle_out_target_size(
163163
out_sizes[i] = in.size(i) * casted_upscale_factor;
164164
}
165165

166+
bool check_split_with_sizes_copy_args(
167+
const Tensor& in,
168+
exec_aten::ArrayRef<int64_t> split_sizes,
169+
int64_t dim,
170+
TensorList out) {
171+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 1));
172+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
173+
174+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
175+
split_sizes.size() == out.size(),
176+
"Number of split sizes must match the number of output tensors");
177+
178+
int64_t sum = 0;
179+
for (int i = 0; i < split_sizes.size(); i++) {
180+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
181+
split_sizes[i] >= 0, "All split sizes must be non negative.");
182+
sum += split_sizes[i];
183+
}
184+
185+
const ssize_t dim_size = in.size(dim);
186+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
187+
sum == dim_size,
188+
"Sum of split sizes does not match input size at given dim");
189+
190+
return true;
191+
}
192+
193+
void get_split_with_sizes_copy_out_target_size(
194+
const Tensor& in,
195+
int64_t split_size,
196+
int64_t dim,
197+
Tensor::SizesType* out_sizes,
198+
size_t* out_ndim) {
199+
*out_ndim = in.dim();
200+
201+
for (size_t d = 0; d < in.dim(); ++d) {
202+
out_sizes[d] = in.size(d);
203+
}
204+
out_sizes[dim] = split_size;
205+
}
206+
166207
bool check_stack_args(
167208
exec_aten::ArrayRef<Tensor> tensors,
168209
int64_t dim,

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@ void get_pixel_shuffle_out_target_size(
4343
Tensor::SizesType* out_sizes,
4444
size_t* out_ndim);
4545

46+
bool check_split_with_sizes_copy_args(
47+
const Tensor& in,
48+
exec_aten::ArrayRef<int64_t> split_sizes,
49+
int64_t dim,
50+
TensorList out);
51+
52+
void get_split_with_sizes_copy_out_target_size(
53+
const Tensor& in,
54+
int64_t split_size,
55+
int64_t dim,
56+
Tensor::SizesType* out_sizes,
57+
size_t* out_ndim);
58+
4659
bool check_stack_args(
4760
exec_aten::ArrayRef<Tensor> tensors,
4861
int64_t dim,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,11 @@
632632
- arg_meta: null
633633
kernel_name: torch::executor::split_copy_Tensor_out
634634

635+
- op: split_with_sizes_copy.out
636+
kernels:
637+
- arg_meta: null
638+
kernel_name: torch::executor::split_with_sizes_copy_out
639+
635640
- op: sqrt.out
636641
kernels:
637642
- arg_meta: null
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/kernels/test/supported_features.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15+
16+
#include <gtest/gtest.h>
17+
18+
using namespace ::testing;
19+
20+
void op_split_with_sizes_copy_out(
21+
const exec_aten::Tensor& self,
22+
exec_aten::ArrayRef<int64_t> split_sizes,
23+
int64_t dim,
24+
exec_aten::TensorList out) {
25+
exec_aten::RuntimeContext context{};
26+
return torch::executor::aten::split_with_sizes_copy_outf(
27+
context, self, split_sizes, dim, out);
28+
}
29+
30+
TEST(OpSplitWithSizesCopyOutTest, SanityCheckDim1) {
31+
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
32+
33+
exec_aten::Tensor self = tfFloat.make(
34+
{2, 6, 3},
35+
{-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125,
36+
14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0,
37+
46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5,
38+
-61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125,
39+
0.375, -61.375, 65.0, -99.375});
40+
::std::vector<int64_t> split_sizes_vec = {3, 1, 2};
41+
exec_aten::ArrayRef<int64_t> split_sizes = exec_aten::ArrayRef<int64_t>(
42+
split_sizes_vec.data(), split_sizes_vec.size());
43+
int64_t dim = 1;
44+
::std::vector<exec_aten::Tensor> out_vec = {
45+
tfFloat.zeros({2, 3, 3}),
46+
tfFloat.zeros({2, 1, 3}),
47+
tfFloat.zeros({2, 2, 3})};
48+
exec_aten::TensorList out =
49+
exec_aten::TensorList(out_vec.data(), out_vec.size());
50+
::std::vector<exec_aten::Tensor> out_expected_vec = {
51+
tfFloat.make(
52+
{2, 3, 3},
53+
{-31.25,
54+
-92.75,
55+
-39.75,
56+
-3.25,
57+
53.875,
58+
88.25,
59+
-0.625,
60+
-1.125,
61+
14.75,
62+
-58.125,
63+
19.625,
64+
-71.125,
65+
64.75,
66+
-1.375,
67+
-83.5,
68+
-61.375,
69+
13.125,
70+
28.625}),
71+
tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}),
72+
tfFloat.make(
73+
{2, 2, 3},
74+
{-8.0,
75+
-64.125,
76+
23.0,
77+
37.0,
78+
46.125,
79+
-83.25,
80+
-88.875,
81+
-79.125,
82+
0.375,
83+
-61.375,
84+
65.0,
85+
-99.375})};
86+
exec_aten::TensorList out_expected =
87+
exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size());
88+
op_split_with_sizes_copy_out(self, split_sizes, dim, out);
89+
EXPECT_TENSOR_LISTS_CLOSE(out, out_expected);
90+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def define_common_targets():
260260
_common_op_test("op_slice_copy_test", ["aten", "portable"])
261261
_common_op_test("op_softmax_test", ["aten", "portable"])
262262
_common_op_test("op_split_copy_test", ["aten", "portable"])
263+
_common_op_test("op_split_with_sizes_copy_test", ["aten", "portable"])
263264
_common_op_test("op_sqrt_test", ["aten", "portable"])
264265
_common_op_test("op_squeeze_copy_test", ["aten", "portable"])
265266
_common_op_test("op_stack_test", ["aten", "portable"])

0 commit comments

Comments
 (0)