Skip to content

Commit 2002490

Browse files
Add op: gather.out
Differential Revision: D61822105 Pull Request resolved: #4939
1 parent 88edab8 commit 2002490

File tree

8 files changed

+543
-0
lines changed

8 files changed

+543
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@
173173

174174
- op: full.out
175175

176+
- op: gather.out
177+
176178
- op: ge.Scalar_out
177179

178180
- op: ge.Tensor_out

kernels/portable/cpu/op_gather.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cinttypes>
10+
#include <cstdint>
11+
#include <cstring>
12+
13+
#include <executorch/kernels/portable/cpu/util/index_util.h>
14+
#include <executorch/runtime/kernel/kernel_includes.h>
15+
16+
namespace torch {
17+
namespace executor {
18+
namespace native {
19+
20+
using Tensor = exec_aten::Tensor;
21+
using ScalarType = exec_aten::ScalarType;
22+
23+
namespace {
24+
25+
template <typename CTYPE>
26+
void gather_helper(
27+
const Tensor& in,
28+
const Tensor& index,
29+
Tensor& out,
30+
int64_t dim) {
31+
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
32+
const long* index_data = index.const_data_ptr<long>();
33+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
34+
35+
if (index.dim() == 0) {
36+
out_data[0] = in_data[index_data[0]];
37+
return;
38+
}
39+
40+
for (size_t ix = 0; ix < index.numel(); ++ix) {
41+
size_t ix_coord[kTensorDimensionLimit];
42+
indexToCoordinate(index, ix, ix_coord);
43+
44+
size_t in_coord[kTensorDimensionLimit];
45+
for (size_t i = 0; i < out.dim(); ++i) {
46+
if (i == dim) {
47+
in_coord[i] = index_data[ix];
48+
} else {
49+
in_coord[i] = ix_coord[i];
50+
}
51+
}
52+
53+
size_t in_ix = coordinateToIndex(in, in_coord);
54+
size_t out_ix = coordinateToIndex(out, ix_coord);
55+
56+
out_data[out_ix] = in_data[in_ix];
57+
}
58+
}
59+
60+
} // namespace
61+
62+
/**
 * gather.out: gathers values of `in` along `dim` at the positions given by
 * `index`, writing into `out` (resized to `index.sizes()`).
 *
 * @param ctx Kernel runtime context used for failure reporting.
 * @param in Source tensor; dtype must match `out`.
 * @param dim Dimension to gather along; may be negative (normalized below).
 * @param index Int64 tensor of positions; validated by check_gather_args.
 * @param sparse_grad Present for ATen schema compatibility; forwarded to the
 *        argument check only.
 * @param out Destination tensor; returned for chaining.
 */
Tensor& gather_out(
    RuntimeContext& ctx,
    const Tensor& in,
    int64_t dim,
    const Tensor& index,
    bool sparse_grad,
    Tensor& out) {
  // NOTE: ctx is used by ET_KERNEL_CHECK and ET_SWITCH below, so no
  // (void)ctx suppression cast is needed.
  ET_KERNEL_CHECK(
      ctx,
      check_gather_args(in, dim, index, sparse_grad, out),
      InvalidArgument,
      out);

  // Normalize a negative dim to its non-negative equivalent.
  if (dim < 0) {
    dim += nonzero_dim(in);
  }

  // gather's output always has the same shape as the index tensor.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, index.sizes()) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "gather.out";

  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
    gather_helper<CTYPE>(in, index, out, dim);
  });

  return out;
}
95+
96+
} // namespace native
97+
} // namespace executor
98+
} // namespace torch

kernels/portable/cpu/util/index_util.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,51 @@
1212
namespace torch {
1313
namespace executor {
1414

15+
bool check_gather_args(
16+
const Tensor& in,
17+
int64_t dim,
18+
const Tensor& index,
19+
bool sparse_grad,
20+
Tensor& out) {
21+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
22+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
23+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
24+
index.scalar_type() == ScalarType::Long,
25+
"Expected dypte int64 for index");
26+
if (index.numel() != 0) {
27+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
28+
nonzero_dim(in) == nonzero_dim(index),
29+
"self and index should have the same dimensionality when index is not empty "
30+
"except for the case when one has dimension 0 and the other has dimension 1");
31+
}
32+
33+
// Normalize dim to non-negative value
34+
if (dim < 0) {
35+
dim += nonzero_dim(in);
36+
}
37+
38+
for (size_t d = 0; d < nonzero_dim(in); ++d) {
39+
if (d != dim) {
40+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
41+
nonempty_size(index, d) <= nonempty_size(in, d),
42+
"size of dimension %zd of index should be smaller than the size of that dimension of input if dimension %zd != dim %zd",
43+
d,
44+
d,
45+
(size_t)dim);
46+
}
47+
}
48+
const long* index_data = index.const_data_ptr<long>();
49+
for (size_t i = 0; i < index.numel(); ++i) {
50+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
51+
index_data[i] >= 0 && index_data[i] < nonempty_size(in, dim),
52+
"Index is out of bounds for dimension %zd with size %zd",
53+
(size_t)dim,
54+
nonempty_size(index, dim));
55+
}
56+
57+
return true;
58+
}
59+
1560
bool check_index_select_args(
1661
const Tensor& in,
1762
int64_t dim,

kernels/portable/cpu/util/index_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
namespace torch {
1515
namespace executor {
1616

17+
/**
 * Validates the arguments of aten::gather.out: in/out dtype agreement, that
 * `dim` is a valid dimension of `in`, that `index` is an int64 tensor whose
 * dimensionality/sizes are compatible with `in`, and that all index values
 * are in bounds along `dim`. Returns true if all checks pass; logs and
 * returns false otherwise.
 */
bool check_gather_args(
    const Tensor& in,
    int64_t dim,
    const Tensor& index,
    bool sparse_grad,
    Tensor& output);
23+
1724
bool check_index_select_args(
1825
const Tensor& in,
1926
int64_t dim,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@
392392
- arg_meta: null
393393
kernel_name: torch::executor::full_like_out
394394

395+
- op: gather.out
396+
kernels:
397+
- arg_meta: null
398+
kernel_name: torch::executor::gather_out
399+
395400
- op: ge.Scalar_out
396401
kernels:
397402
- arg_meta: null

0 commit comments

Comments
 (0)