Skip to content

Commit bf33da4

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: index_select
Reviewed By: SS-JIA Differential Revision: D48289831 fbshipit-source-id: 7fa7a0ef02205379d4bed1a996c1bbecb17aa4eb
1 parent 076fd09 commit bf33da4

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

kernels/portable/cpu/op_index_select.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,12 @@ void check_index_select_args(
6363
index.numel(),
6464
trailing_dims);
6565

66-
// Index should be a 1-D LongTensor, check if any index is out of bound
66+
// Index should be a 1-D Long or Int Tensor. Check if any index is out of
67+
// bound
68+
ScalarType ix_type = index.scalar_type();
6769
ET_CHECK_MSG(
68-
index.scalar_type() == ScalarType::Long, "index scalar_type not long");
70+
ix_type == ScalarType::Long || ix_type == ScalarType::Int,
71+
"Expected index tensor to have Long or Int scalar types");
6972
ET_CHECK_MSG(
7073
index.dim() == 1 || index.dim() == 0,
7174
"index.dim() %zd != 1 or 0",
@@ -138,16 +141,22 @@ Tensor& index_select_out(
138141

139142
const char* input_data = input.const_data_ptr<char>();
140143
char* out_data = out.mutable_data_ptr<char>();
141-
const int64_t* index_arr = index.mutable_data_ptr<int64_t>();
142-
for (int i = 0; i < leading_dims; i++) {
143-
const char* src = input_data + i * in_dim_length * length_per_step;
144-
char* dest = out_data + i * out_dim_length * length_per_step;
145-
for (auto j = 0; j < out_dim_length; j++) {
146-
const char* copy_src = src + index_arr[j] * length_per_step;
147-
memcpy(dest, copy_src, length_per_step);
148-
dest += length_per_step;
144+
145+
ScalarType ix_type = index.scalar_type();
146+
147+
ET_SWITCH_TWO_TYPES(Long, Int, ix_type, ctx, __func__, CTYPE, [&]() {
148+
const CTYPE* const index_arr = index.mutable_data_ptr<CTYPE>();
149+
for (int i = 0; i < leading_dims; i++) {
150+
const char* src = input_data + i * in_dim_length * length_per_step;
151+
char* dest = out_data + i * out_dim_length * length_per_step;
152+
for (auto j = 0; j < out_dim_length; j++) {
153+
const char* copy_src = src + index_arr[j] * length_per_step;
154+
memcpy(dest, copy_src, length_per_step);
155+
dest += length_per_step;
156+
}
149157
}
150-
}
158+
});
159+
151160
return out;
152161
}
153162

0 commit comments

Comments
 (0)