@@ -22,29 +22,13 @@ using Tensor = exec_aten::Tensor;
22
22
23
23
namespace {
24
24
25
- void check_index_args (
26
- const Tensor& input,
27
- exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
28
- Tensor& output) {
29
- // size of indices must not exceed the number of dimensions
30
- ET_CHECK_MSG (
31
- indices.size () <= input.dim (),
32
- " indices.size() %zd > input.dim() %zd" ,
33
- ssize_t (indices.size ()),
34
- ssize_t (input.dim ()));
35
-
36
- check_indices (input, indices);
37
-
38
- check_index_result_size (input, indices, output);
39
- }
40
-
41
25
template <typename CTYPE_IN, typename CTYPE_OUT>
42
26
void index_out_impl_mask (
43
- const Tensor& input ,
27
+ const Tensor& in ,
44
28
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
45
29
Tensor& out) {
46
30
// Data pointers
47
- const CTYPE_IN* const in_data = input .const_data_ptr <CTYPE_IN>();
31
+ const CTYPE_IN* const in_data = in .const_data_ptr <CTYPE_IN>();
48
32
CTYPE_OUT* const out_data = out.mutable_data_ptr <CTYPE_OUT>();
49
33
50
34
const Tensor& mask = indices[0 ].value ();
@@ -60,11 +44,11 @@ void index_out_impl_mask(
60
44
61
45
template <typename CTYPE_IN, typename CTYPE_OUT>
62
46
void index_out_impl_list (
63
- const Tensor& input ,
47
+ const Tensor& in ,
64
48
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
65
49
Tensor& out) {
66
50
// Data pointers
67
- const CTYPE_IN* const in_data = input .const_data_ptr <CTYPE_IN>();
51
+ const CTYPE_IN* const in_data = in .const_data_ptr <CTYPE_IN>();
68
52
CTYPE_OUT* dst = out.mutable_data_ptr <CTYPE_OUT>();
69
53
70
54
size_t num_idx_queries = get_indices_broadcast_len (indices);
@@ -73,13 +57,13 @@ void index_out_impl_list(
73
57
74
58
// For each index query, align the src and dst pointers to the position
75
59
// described by the query.
76
- size_t offset = get_index_query_pos_offset (idx, input , indices);
60
+ size_t offset = get_index_query_pos_offset (idx, in , indices);
77
61
src += offset;
78
62
79
63
// Calculate the region of data to copy for this query.
80
64
// For example, a 2x4x3x5 tensor indexing at [1, 1, :, :] should copy 15
81
65
// elements.
82
- size_t copy_len = getTrailingDims (input , indices.size () - 1 );
66
+ size_t copy_len = getTrailingDims (in , indices.size () - 1 );
83
67
84
68
for (size_t i = 0 ; i < copy_len; ++i) {
85
69
dst[i] = static_cast <CTYPE_OUT>(src[i]);
@@ -88,107 +72,50 @@ void index_out_impl_list(
88
72
}
89
73
}
90
74
91
- template <typename CTYPE_IN, typename CTYPE_OUT>
92
- void index_out_impl (
93
- const Tensor& input,
94
- exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
95
- Tensor& out) {
96
- if (is_index_mask (input, indices)) {
97
- index_out_impl_mask<CTYPE_IN, CTYPE_OUT>(input, indices, out);
98
- } else {
99
- index_out_impl_list<CTYPE_IN, CTYPE_OUT>(input, indices, out);
100
- }
101
- }
102
-
103
- template <typename CTYPE_IN>
104
- inline void index_out_switch_out (
105
- const Tensor& input,
106
- exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
107
- Tensor& out) {
108
- auto out_type = out.scalar_type ();
109
- #define INDEX_COPY_SWITCH_OUTPUT_CASE (ctype, dtype ) \
110
- case ScalarType::dtype: \
111
- index_out_impl<CTYPE_IN, ctype>(input, indices, out); \
112
- break ;
113
-
114
- switch (out_type) {
115
- ET_FORALL_REAL_TYPES_AND (Bool, INDEX_COPY_SWITCH_OUTPUT_CASE);
116
- default :
117
- ET_CHECK_MSG (
118
- false , " %hhd scalar type is not supported for output" , out_type);
119
- }
120
-
121
- #undef INDEX_COPY_SWITCH_OUTPUT_CASE
122
- }
123
-
124
- inline void index_out_switch_input (
125
- const Tensor& input,
126
- exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
127
- Tensor& out) {
128
- auto input_type = input.scalar_type ();
129
- #define INDEX_COPY_SWITCH_INPUT_CASE (ctype, dtype ) \
130
- case ScalarType::dtype: \
131
- index_out_switch_out<ctype>(input, indices, out); \
132
- break ;
133
-
134
- switch (input_type) {
135
- ET_FORALL_REAL_TYPES_AND (Bool, INDEX_COPY_SWITCH_INPUT_CASE);
136
- default :
137
- ET_CHECK_MSG (
138
- false , " %hhd scalar type is not supported for input" , input_type);
139
- }
140
-
141
- #undef INDEX_COPY_SWITCH_INPUT_CASE
142
- }
143
-
144
- // expected output dim: 1 + (remaining dimension). Shape: [indices.size,
145
- // *remaining dimension shape]. E.g., 3x3x3x3 tensor, index at [(1, 2), (0,
146
- // 2), :, :] gives output shape [2, 3, 3].
147
- Error resize_out (
148
- const Tensor& input,
149
- Tensor& out,
150
- ArrayRef<exec_aten::optional<Tensor>> indices) {
151
- size_t out_ndim = 0 ;
152
- Tensor::SizesType out_sizes[kTensorDimensionLimit ];
153
- get_index_result_size (input, indices, out_sizes, out_ndim);
154
-
155
- ArrayRef<Tensor::SizesType> output_size{out_sizes, out_ndim};
156
- auto error = resize_tensor (out, output_size);
157
-
158
- return error;
159
- }
160
75
} // namespace
161
76
162
- // / aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) ->
163
- // / Tensor(a!)
164
77
/// aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out)
///     -> Tensor(a!)
///
/// Copies the elements of `in` selected by `indices` into `out`. Dispatch
/// below picks one of two implementations: a single boolean-mask index
/// (index_out_impl_mask) or a list of index tensors (index_out_impl_list).
/// Returns `out` in all cases, including early-outs on check failure
/// (via ET_KERNEL_CHECK) and on empty input.
Tensor& index_Tensor_out(
    RuntimeContext& ctx,
    const Tensor& in,
    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
    Tensor& out) {
  // Validate arguments up front; on failure ET_KERNEL_CHECK records
  // InvalidArgument on the context and returns `out` unchanged.
  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  // An empty index list selects the whole tensor: resize `out` to match
  // `in` and perform a straight bitwise copy.
  if (indices.empty()) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
    memcpy(
        out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
    return out;
  }

  // Compute the expected output shape for this index query, then resize
  // `out` to that shape before writing into it.
  size_t expected_ndim = 0;
  Tensor::SizesType expected_size[kTensorDimensionLimit];
  get_index_out_target_size(in, indices, expected_size, &expected_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  // NOTE(review): a second, bare `check_index_args(in, indices, out);` call
  // used to sit here with its boolean result discarded — a no-op duplicate
  // of the ET_KERNEL_CHECK above, so it has been removed.

  // Nothing to copy from an empty input.
  if (in.numel() == 0) {
    return out;
  }

  // Dispatch on the (input dtype, output dtype) pair over all real types
  // plus Bool, then choose the mask or index-list implementation.
  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();
  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "index", CTYPE_IN, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "index", CTYPE_OUT, [&]() {
      if (is_index_mask(in, indices)) {
        index_out_impl_mask<CTYPE_IN, CTYPE_OUT>(in, indices, out);
      } else {
        index_out_impl_list<CTYPE_IN, CTYPE_OUT>(in, indices, out);
      }
    });
  });

  return out;
}
0 commit comments