@@ -35,7 +35,9 @@ bool is_index_mask(
35
35
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices) {
36
36
if (indices.size () == 1 && indices[0 ].has_value ()) {
37
37
const Tensor& mask = indices[0 ].value ();
38
- if (mask.scalar_type () == ScalarType::Bool && mask.dim () == tensor.dim ()) {
38
+ if ((mask.scalar_type () == ScalarType::Bool ||
39
+ mask.scalar_type () == ScalarType::Byte) &&
40
+ mask.dim () == tensor.dim ()) {
39
41
return true ;
40
42
}
41
43
}
@@ -49,7 +51,8 @@ size_t get_indices_broadcast_len(
49
51
if (indices[i].has_value ()) {
50
52
const Tensor& index = indices[i].value ();
51
53
size_t len = 0 ;
52
- if (index.scalar_type () == ScalarType::Bool) {
54
+ if (index.scalar_type () == ScalarType::Bool ||
55
+ index.scalar_type () == ScalarType::Byte) {
53
56
len = count_boolean_index (index);
54
57
} else {
55
58
len = index.numel ();
@@ -99,7 +102,7 @@ bool indices_list_is_valid(
99
102
ET_LOG_AND_RETURN_IF_FALSE (
100
103
index_values_are_valid<int64_t >(tensor, i, index));
101
104
}
102
- } else if (idx_type == ScalarType::Bool) {
105
+ } else if (idx_type == ScalarType::Bool || idx_type == ScalarType::Byte ) {
103
106
ET_LOG_MSG_AND_RETURN_IF_FALSE (
104
107
index.numel () == tensor.size (i),
105
108
" indices[%zd].numel() %zd incompatible with input.size(%zd) %zd" ,
@@ -162,7 +165,7 @@ size_t get_index_query_pos_offset(
162
165
index_val += tensor.size (dim);
163
166
}
164
167
offset += index_val * step_len;
165
- } else if (idx_type == ScalarType::Bool) {
168
+ } else if (idx_type == ScalarType::Bool || idx_type == ScalarType::Byte ) {
166
169
const bool * const index_ptr = index.const_data_ptr <bool >();
167
170
// Broadcasting for boolean index tensors
168
171
size_t num_true = count_boolean_index (index);
0 commit comments