Skip to content

Commit 7b3b485

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Half support for index op
Reviewed By: cccclai Differential Revision: D56543186 fbshipit-source-id: 4fed6b9b3ede3cdcb67a9a52150e3f22cc02b180
1 parent 8ec0af9 commit 7b3b485

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,11 @@ Tensor& index_Tensor_out(
4040
if (block_count == 0) {
4141
ET_KERNEL_CHECK(
4242
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
43-
ET_SWITCH_REAL_TYPES_AND(
44-
Bool, in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
45-
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
46-
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
47-
memcpy(out_data, in_data, in.nbytes());
48-
});
43+
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
44+
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
45+
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
46+
memcpy(out_data, in_data, in.nbytes());
47+
});
4948
return out;
5049
}
5150

@@ -85,20 +84,19 @@ Tensor& index_Tensor_out(
8584
compute_dim_map(in, indices, dim_map, block_count == 1);
8685
compute_index_map(in, indices, ix_map);
8786

88-
ET_SWITCH_REAL_TYPES_AND(
89-
Bool, in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
90-
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
91-
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
92-
93-
for (auto out_ix = 0; out_ix < out.numel(); out_ix++) {
94-
size_t in_ix = 0;
95-
bool success = true;
96-
std::tie(in_ix, success) =
97-
get_in_ix(in, indices, out, out_ix, start, xdim, dim_map, ix_map);
98-
ET_KERNEL_CHECK(ctx, success, InvalidArgument, );
99-
out_data[out_ix] = in_data[in_ix];
100-
}
101-
});
87+
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
88+
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
89+
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
90+
91+
for (auto out_ix = 0; out_ix < out.numel(); out_ix++) {
92+
size_t in_ix = 0;
93+
bool success = true;
94+
std::tie(in_ix, success) =
95+
get_in_ix(in, indices, out, out_ix, start, xdim, dim_map, ix_map);
96+
ET_KERNEL_CHECK(ctx, success, InvalidArgument, );
97+
out_data[out_ix] = in_data[in_ix];
98+
}
99+
});
102100

103101
return out;
104102
}

0 commit comments

Comments
 (0)