Commit fd1a338

tarun292 authored and facebook-github-bot committed
Refactor select_copy into util (#2784)
Summary:
Pull Request resolved: #2784

Refactoring the core logic of select_copy into a helper utility function so that it can be utilized by other operators that internally need to do a select operation.

Reviewed By: iseeyuan

Differential Revision: D55565485

fbshipit-source-id: 387212824c23613efe45f4be647de15763dad5aa
1 parent 067a829 commit fd1a338
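
As context for how the new helper is meant to be reused, here is an illustrative sketch (not code from this commit): the wrapper name below is hypothetical, and only select_copy_util and its signature come from the diff.

#include <executorch/kernels/portable/cpu/util/select_copy_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

// Hypothetical caller: another operator that internally needs to select a
// single slice along `dim` can now delegate to the shared utility instead of
// re-implementing the memcpy loop. select_copy_util validates the arguments,
// resizes `out`, performs the copy, and reports failures through its return
// value, which the caller can forward (e.g. to ctx.fail, as op_select_copy
// now does).
Error copy_selected_slice(
    const Tensor& in,
    int64_t dim,
    int64_t index,
    Tensor& out) {
  return select_copy_util(in, dim, index, out);
}

} // namespace executor
} // namespace torch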

File tree

5 files changed: +123 -59 lines changed

kernels/portable/cpu/op_select_copy.cpp
kernels/portable/cpu/targets.bzl
kernels/portable/cpu/util/select_copy_util.cpp
kernels/portable/cpu/util/select_copy_util.h
kernels/portable/cpu/util/targets.bzl

kernels/portable/cpu/op_select_copy.cpp

Lines changed: 4 additions & 59 deletions
@@ -8,7 +8,7 @@
 
 #include <cstring>
 
-#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/select_copy_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -23,64 +23,9 @@ Tensor& select_copy_int_out(
     int64_t dim,
     int64_t index,
     Tensor& out) {
-  (void)ctx;
-
-  ET_KERNEL_CHECK(
-      ctx,
-      check_select_copy_out_args(in, dim, index, out),
-      InvalidArgument,
-      out);
-
-  if (dim < 0) {
-    dim += nonzero_dim(in);
-  }
-
-  Tensor::SizesType target_sizes[kTensorDimensionLimit];
-  size_t target_ndim = 0;
-  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);
-
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // If the input is an empty tensor, no other operation can be done. We just
-  // return the output.
-  if (in.numel() == 0) {
-    return out;
-  }
-  // The code past this point assumes that the tensors are non-empty.
-
-  // Support python-style negative indexing
-  if (index < 0) {
-    index += in.size(dim);
-  }
-
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-  size_t dim_length = in.size(dim);
-
-  // Number of bytes to copy in each memcpy operation
-  size_t copy_size_per_op = trailing_dims * out.element_size();
-
-  // Step between the src locations of two adjacent memcpy operations
-  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();
-
-  // The start point of the data to be copied is the start of the overall
-  // data chunk plus the offset between that start point and the first
-  // element to be copied.
-  char* input_data = in.mutable_data_ptr<char>();
-
-  size_t start_offset = index * trailing_dims * in.element_size();
-  char* src = input_data + start_offset;
-
-  char* dest = out.mutable_data_ptr<char>();
-
-  for (size_t j = 0; j < leading_dims; ++j) {
-    memcpy(dest, src, copy_size_per_op);
-    src += src_step_per_op;
-    dest += copy_size_per_op;
+  Error err = torch::executor::select_copy_util(in, dim, index, out);
+  if (err != Error::Ok) {
+    ctx.fail(err);
   }
   return out;
 }

kernels/portable/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -794,6 +794,7 @@ _ATEN_OPS = (
         name = "op_select_copy",
         deps = [
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
+            "//executorch/kernels/portable/cpu/util:select_copy_util",
         ],
     ),
     op_target(

kernels/portable/cpu/util/select_copy_util.cpp

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include "executorch/kernels/portable/cpu/util/select_copy_util.h"
+
+namespace torch {
+namespace executor {
+
+using Tensor = exec_aten::Tensor;
+
+Error select_copy_util(
+    const Tensor& in,
+    int64_t dim,
+    int64_t index,
+    Tensor& out) {
+  if (!check_select_copy_out_args(in, dim, index, out)) {
+    return Error::InvalidArgument;
+  }
+
+  if (dim < 0) {
+    dim += nonzero_dim(in);
+  }
+
+  Tensor::SizesType target_sizes[kTensorDimensionLimit];
+  size_t target_ndim = 0;
+  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);
+
+  if (!(resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok)) {
+    return Error::InvalidArgument;
+  }
+
+  // If the input is an empty tensor, no other operation can be done. We just
+  // return the output.
+  if (in.numel() == 0) {
+    return Error::Ok;
+  }
+  // The code past this point assumes that the tensors are non-empty.
+
+  // Support python-style negative indexing
+  if (index < 0) {
+    index += in.size(dim);
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+  size_t dim_length = in.size(dim);
+
+  // Number of bytes to copy in each memcpy operation
+  size_t copy_size_per_op = trailing_dims * out.element_size();
+
+  // Step between the src locations of two adjacent memcpy operations
+  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();
+
+  // The start point of the data to be copied is the start of the overall
+  // data chunk plus the offset between that start point and the first
+  // element to be copied.
+  char* input_data = in.mutable_data_ptr<char>();
+
+  size_t start_offset = index * trailing_dims * in.element_size();
+  char* src = input_data + start_offset;
+
+  char* dest = out.mutable_data_ptr<char>();
+
+  for (size_t j = 0; j < leading_dims; ++j) {
+    memcpy(dest, src, copy_size_per_op);
+    src += src_step_per_op;
+    dest += copy_size_per_op;
+  }
+
+  return Error::Ok;
+}
+
+} // namespace executor
+} // namespace torch
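
To make the copy geometry above concrete, here is a small standalone sketch (illustrative only, not part of the commit) that mirrors the same arithmetic for a contiguous float tensor of shape [2, 3, 4] with dim = 1 and index = 2; it prints copies=2, bytes/copy=16, src step=48, start offset=32.

// Standalone sketch (not ExecuTorch code) that reproduces the pointer
// arithmetic of select_copy_util for one concrete case: selecting index 2
// along dim 1 of a contiguous float tensor of shape [2, 3, 4].
#include <cstddef>
#include <cstdio>

int main() {
  const size_t sizes[] = {2, 3, 4};   // tensor shape
  const size_t dim = 1;               // dimension being selected
  const size_t index = 2;             // slice index along that dimension
  const size_t elem = sizeof(float);  // element size in bytes

  size_t leading_dims = 1;            // product of sizes before `dim`
  for (size_t i = 0; i < dim; ++i) leading_dims *= sizes[i];
  size_t trailing_dims = 1;           // product of sizes after `dim`
  for (size_t i = dim + 1; i < 3; ++i) trailing_dims *= sizes[i];
  size_t dim_length = sizes[dim];

  size_t copy_size_per_op = trailing_dims * elem;              // 16 bytes per memcpy
  size_t src_step_per_op = dim_length * trailing_dims * elem;  // 48-byte stride between copies
  size_t start_offset = index * trailing_dims * elem;          // first copy starts at byte 32

  printf("copies=%zu, bytes/copy=%zu, src step=%zu, start offset=%zu\n",
         leading_dims, copy_size_per_op, src_step_per_op, start_offset);
  return 0;
}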

kernels/portable/cpu/util/select_copy_util.h

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+Error select_copy_util(
+    const Tensor& in,
+    int64_t dim,
+    int64_t index,
+    Tensor& out);
+
+} // namespace executor
+} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 12 additions & 0 deletions
@@ -191,6 +191,18 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
     )
 
+    runtime.cxx_library(
+        name = "select_copy_util",
+        srcs = ["select_copy_util.cpp"],
+        exported_headers = ["select_copy_util.h"],
+        deps = [
+            ":copy_ops_util",
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/..."],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in [True, False]:
         suffix = "_aten" if aten_mode else ""
