@@ -63,9 +63,12 @@ void check_index_select_args(
63
63
index.numel (),
64
64
trailing_dims);
65
65
66
- // Index should be a 1-D LongTensor, check if any index is out of bound
66
+ // Index should be a 1-D Long or Int Tensor. Check if any index is out of
67
+ // bound
68
+ ScalarType ix_type = index.scalar_type ();
67
69
ET_CHECK_MSG (
68
- index.scalar_type () == ScalarType::Long, " index scalar_type not long" );
70
+ ix_type == ScalarType::Long || ix_type == ScalarType::Int,
71
+ " Expected index tensor to have Long or Int scalar types" );
69
72
ET_CHECK_MSG (
70
73
index.dim () == 1 || index.dim () == 0 ,
71
74
" index.dim() %zd != 1 or 0" ,
@@ -138,16 +141,22 @@ Tensor& index_select_out(
138
141
139
142
const char * input_data = input.const_data_ptr <char >();
140
143
char * out_data = out.mutable_data_ptr <char >();
141
- const int64_t * index_arr = index.mutable_data_ptr <int64_t >();
142
- for (int i = 0 ; i < leading_dims; i++) {
143
- const char * src = input_data + i * in_dim_length * length_per_step;
144
- char * dest = out_data + i * out_dim_length * length_per_step;
145
- for (auto j = 0 ; j < out_dim_length; j++) {
146
- const char * copy_src = src + index_arr[j] * length_per_step;
147
- memcpy (dest, copy_src, length_per_step);
148
- dest += length_per_step;
144
+
145
+ ScalarType ix_type = index.scalar_type ();
146
+
147
+ ET_SWITCH_TWO_TYPES (Long, Int, ix_type, ctx, __func__, CTYPE, [&]() {
148
+ const CTYPE* const index_arr = index.mutable_data_ptr <CTYPE>();
149
+ for (int i = 0 ; i < leading_dims; i++) {
150
+ const char * src = input_data + i * in_dim_length * length_per_step;
151
+ char * dest = out_data + i * out_dim_length * length_per_step;
152
+ for (auto j = 0 ; j < out_dim_length; j++) {
153
+ const char * copy_src = src + index_arr[j] * length_per_step;
154
+ memcpy (dest, copy_src, length_per_step);
155
+ dest += length_per_step;
156
+ }
149
157
}
150
- }
158
+ });
159
+
151
160
return out;
152
161
}
153
162
0 commit comments