Skip to content

Commit 77cb4a5

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix op index_select (#756)
Summary: Pull Request resolved: #756 ghstack-source-id: 203458981 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D50104697 fbshipit-source-id: 4209839556311ea20141d8dcaa585bc70d47d9a3
1 parent 09e7cba commit 77cb4a5

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

kernels/portable/cpu/op_index_select.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ Tensor& index_select_out(
2828
ET_KERNEL_CHECK(
2929
ctx, check_index_select_args(in, dim, index, out), InvalidArgument, out);
3030

31+
if (dim < 0) {
32+
dim += nonzero_dim(in);
33+
}
34+
3135
size_t expected_ndim = 0;
3236
Tensor::SizesType expected_size[kTensorDimensionLimit];
3337
get_index_select_out_target_size(
3438
in, dim, index, expected_size, &expected_ndim);
39+
3540
ET_KERNEL_CHECK(
3641
ctx,
3742
resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,

kernels/portable/cpu/util/index_util.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ bool check_index_select_args(
279279

280280
if (index.scalar_type() == ScalarType::Long) {
281281
const int64_t* const index_ptr = index.const_data_ptr<int64_t>();
282-
for (size_t i = 1; i < index.numel(); ++i) {
282+
for (size_t i = 0; i < index.numel(); ++i) {
283283
ET_LOG_MSG_AND_RETURN_IF_FALSE(
284284
index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
285285
"index[%zu] = %" PRId64 " is out of range [0, %zd)",
@@ -289,7 +289,7 @@ bool check_index_select_args(
289289
}
290290
} else {
291291
const int32_t* const index_ptr = index.const_data_ptr<int32_t>();
292-
for (size_t i = 1; i < index.numel(); ++i) {
292+
for (size_t i = 0; i < index.numel(); ++i) {
293293
ET_LOG_MSG_AND_RETURN_IF_FALSE(
294294
index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
295295
"index[%zu] = %" PRId32 " is out of range [0, %zd)",

0 commit comments

Comments
 (0)