Skip to content

Commit 87cf2bd

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Dtype compliance: index & index_put
Reviewed By: salilsdesai Differential Revision: D48415576 fbshipit-source-id: 671007c9e6753d0bce8c64341cff9c15c388b432
1 parent 102fe53 commit 87cf2bd

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

kernels/portable/cpu/util/index_util.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ bool is_index_mask(
3535
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices) {
3636
if (indices.size() == 1 && indices[0].has_value()) {
3737
const Tensor& mask = indices[0].value();
38-
if (mask.scalar_type() == ScalarType::Bool && mask.dim() == tensor.dim()) {
38+
if ((mask.scalar_type() == ScalarType::Bool ||
39+
mask.scalar_type() == ScalarType::Byte) &&
40+
mask.dim() == tensor.dim()) {
3941
return true;
4042
}
4143
}
@@ -49,7 +51,8 @@ size_t get_indices_broadcast_len(
4951
if (indices[i].has_value()) {
5052
const Tensor& index = indices[i].value();
5153
size_t len = 0;
52-
if (index.scalar_type() == ScalarType::Bool) {
54+
if (index.scalar_type() == ScalarType::Bool ||
55+
index.scalar_type() == ScalarType::Byte) {
5356
len = count_boolean_index(index);
5457
} else {
5558
len = index.numel();
@@ -99,7 +102,7 @@ bool indices_list_is_valid(
99102
ET_LOG_AND_RETURN_IF_FALSE(
100103
index_values_are_valid<int64_t>(tensor, i, index));
101104
}
102-
} else if (idx_type == ScalarType::Bool) {
105+
} else if (idx_type == ScalarType::Bool || idx_type == ScalarType::Byte) {
103106
ET_LOG_MSG_AND_RETURN_IF_FALSE(
104107
index.numel() == tensor.size(i),
105108
"indices[%zd].numel() %zd incompatible with input.size(%zd) %zd",
@@ -162,7 +165,7 @@ size_t get_index_query_pos_offset(
162165
index_val += tensor.size(dim);
163166
}
164167
offset += index_val * step_len;
165-
} else if (idx_type == ScalarType::Bool) {
168+
} else if (idx_type == ScalarType::Bool || idx_type == ScalarType::Byte) {
166169
const bool* const index_ptr = index.const_data_ptr<bool>();
167170
// Broadcasting for boolean index tensors
168171
size_t num_true = count_boolean_index(index);

0 commit comments

Comments
 (0)