Skip to content

Commit 102fe53

Browse files
SS-JIA authored and facebook-github-bot committed
Modernize indexing kernels
Reviewed By: manuelcandales
Differential Revision: D48405501
fbshipit-source-id: 672ce23825a364200949e66355b4b9636ed38963
1 parent: 16f3933 · commit: 102fe53

File tree

7 files changed

+309
-431
lines changed

7 files changed

+309
-431
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 34 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -22,29 +22,13 @@ using Tensor = exec_aten::Tensor;
2222

2323
namespace {
2424

25-
void check_index_args(
26-
const Tensor& input,
27-
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
28-
Tensor& output) {
29-
// size of indices must not exceed the number of dimensions
30-
ET_CHECK_MSG(
31-
indices.size() <= input.dim(),
32-
"indices.size() %zd > input.dim() %zd",
33-
ssize_t(indices.size()),
34-
ssize_t(input.dim()));
35-
36-
check_indices(input, indices);
37-
38-
check_index_result_size(input, indices, output);
39-
}
40-
4125
template <typename CTYPE_IN, typename CTYPE_OUT>
4226
void index_out_impl_mask(
43-
const Tensor& input,
27+
const Tensor& in,
4428
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
4529
Tensor& out) {
4630
// Data pointers
47-
const CTYPE_IN* const in_data = input.const_data_ptr<CTYPE_IN>();
31+
const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
4832
CTYPE_OUT* const out_data = out.mutable_data_ptr<CTYPE_OUT>();
4933

5034
const Tensor& mask = indices[0].value();
@@ -60,11 +44,11 @@ void index_out_impl_mask(
6044

6145
template <typename CTYPE_IN, typename CTYPE_OUT>
6246
void index_out_impl_list(
63-
const Tensor& input,
47+
const Tensor& in,
6448
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
6549
Tensor& out) {
6650
// Data pointers
67-
const CTYPE_IN* const in_data = input.const_data_ptr<CTYPE_IN>();
51+
const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
6852
CTYPE_OUT* dst = out.mutable_data_ptr<CTYPE_OUT>();
6953

7054
size_t num_idx_queries = get_indices_broadcast_len(indices);
@@ -73,13 +57,13 @@ void index_out_impl_list(
7357

7458
// For each index query, align the src and dst pointers to the position
7559
// described by the query.
76-
size_t offset = get_index_query_pos_offset(idx, input, indices);
60+
size_t offset = get_index_query_pos_offset(idx, in, indices);
7761
src += offset;
7862

7963
// Calculate the region of data to copy for this query.
8064
// For example, a 2x4x3x5 tensor indexing at [1, 1, :, :] should copy 15
8165
// elements.
82-
size_t copy_len = getTrailingDims(input, indices.size() - 1);
66+
size_t copy_len = getTrailingDims(in, indices.size() - 1);
8367

8468
for (size_t i = 0; i < copy_len; ++i) {
8569
dst[i] = static_cast<CTYPE_OUT>(src[i]);
@@ -88,107 +72,50 @@ void index_out_impl_list(
8872
}
8973
}
9074

91-
template <typename CTYPE_IN, typename CTYPE_OUT>
92-
void index_out_impl(
93-
const Tensor& input,
94-
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
95-
Tensor& out) {
96-
if (is_index_mask(input, indices)) {
97-
index_out_impl_mask<CTYPE_IN, CTYPE_OUT>(input, indices, out);
98-
} else {
99-
index_out_impl_list<CTYPE_IN, CTYPE_OUT>(input, indices, out);
100-
}
101-
}
102-
103-
template <typename CTYPE_IN>
104-
inline void index_out_switch_out(
105-
const Tensor& input,
106-
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
107-
Tensor& out) {
108-
auto out_type = out.scalar_type();
109-
#define INDEX_COPY_SWITCH_OUTPUT_CASE(ctype, dtype) \
110-
case ScalarType::dtype: \
111-
index_out_impl<CTYPE_IN, ctype>(input, indices, out); \
112-
break;
113-
114-
switch (out_type) {
115-
ET_FORALL_REAL_TYPES_AND(Bool, INDEX_COPY_SWITCH_OUTPUT_CASE);
116-
default:
117-
ET_CHECK_MSG(
118-
false, "%hhd scalar type is not supported for output", out_type);
119-
}
120-
121-
#undef INDEX_COPY_SWITCH_OUTPUT_CASE
122-
}
123-
124-
inline void index_out_switch_input(
125-
const Tensor& input,
126-
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
127-
Tensor& out) {
128-
auto input_type = input.scalar_type();
129-
#define INDEX_COPY_SWITCH_INPUT_CASE(ctype, dtype) \
130-
case ScalarType::dtype: \
131-
index_out_switch_out<ctype>(input, indices, out); \
132-
break;
133-
134-
switch (input_type) {
135-
ET_FORALL_REAL_TYPES_AND(Bool, INDEX_COPY_SWITCH_INPUT_CASE);
136-
default:
137-
ET_CHECK_MSG(
138-
false, "%hhd scalar type is not supported for input", input_type);
139-
}
140-
141-
#undef INDEX_COPY_SWITCH_INPUT_CASE
142-
}
143-
144-
// expected output dim: 1 + (remaining dimension). Shape: [indices.size,
145-
// *remaining dimension shape]. E.g., 3x3x3x3 tensor, index at [(1, 2), (0,
146-
// 2), :, :] gives output shape [2, 3, 3].
147-
Error resize_out(
148-
const Tensor& input,
149-
Tensor& out,
150-
ArrayRef<exec_aten::optional<Tensor>> indices) {
151-
size_t out_ndim = 0;
152-
Tensor::SizesType out_sizes[kTensorDimensionLimit];
153-
get_index_result_size(input, indices, out_sizes, out_ndim);
154-
155-
ArrayRef<Tensor::SizesType> output_size{out_sizes, out_ndim};
156-
auto error = resize_tensor(out, output_size);
157-
158-
return error;
159-
}
16075
} // namespace
16176

162-
/// aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) ->
163-
/// Tensor(a!)
16477
Tensor& index_Tensor_out(
16578
RuntimeContext& ctx,
166-
const Tensor& input,
79+
const Tensor& in,
16780
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
16881
Tensor& out) {
169-
(void)ctx;
82+
ET_KERNEL_CHECK(
83+
ctx, check_index_args(in, indices, out), InvalidArgument, out);
17084

17185
if (indices.empty()) {
172-
auto error = resize_tensor(out, input.sizes());
173-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
86+
ET_KERNEL_CHECK(
87+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
17488
memcpy(
175-
out.mutable_data_ptr<char>(),
176-
input.const_data_ptr<char>(),
177-
input.nbytes());
89+
out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
17890
return out;
17991
}
18092

181-
// resize out tensor
182-
auto error = resize_out(input, out, indices);
183-
// TODO: Construct error message with requested output sizes.
184-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
185-
check_index_args(input, indices, out);
93+
size_t expected_ndim = 0;
94+
Tensor::SizesType expected_size[kTensorDimensionLimit];
95+
get_index_out_target_size(in, indices, expected_size, &expected_ndim);
96+
ET_KERNEL_CHECK(
97+
ctx,
98+
resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
99+
InvalidArgument,
100+
out);
101+
102+
check_index_args(in, indices, out);
186103

187-
if (input.numel() == 0) {
104+
if (in.numel() == 0) {
188105
return out;
189106
}
190107

191-
index_out_switch_input(input, indices, out);
108+
ScalarType in_type = in.scalar_type();
109+
ScalarType out_type = out.scalar_type();
110+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "index", CTYPE_IN, [&]() {
111+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "index", CTYPE_OUT, [&]() {
112+
if (is_index_mask(in, indices)) {
113+
index_out_impl_mask<CTYPE_IN, CTYPE_OUT>(in, indices, out);
114+
} else {
115+
index_out_impl_list<CTYPE_IN, CTYPE_OUT>(in, indices, out);
116+
}
117+
});
118+
});
192119

193120
return out;
194121
}

0 commit comments

Comments
 (0)