Commit bd768bf

Add op: masked_select
Differential Revision: D65497030
Pull Request resolved: #6670
1 parent 9393b8c commit bd768bf

6 files changed: 277 additions, 0 deletions

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions

@@ -243,6 +243,8 @@

 - op: masked_scatter.out

+- op: masked_select.out
+
 - op: max_pool2d_with_indices.out

 - op: max.dim_max
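
For context on what this entry registers: masked_select keeps the elements of the (broadcast) input at positions where the (broadcast) boolean mask is true and packs them into a 1-D output. The following is an illustrative, standalone C++ sketch of that semantics for the non-broadcast case only; it is not ExecuTorch code, and it simply mirrors the SmokeTest further down.

// Illustration only: keep in[i] wherever mask[i] is true, packed into a
// 1-D output. Values match the SmokeTest below, which expects {1, 4, 6}.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in = {1, 2, 3, 4, 5, 6};
  std::vector<bool> mask = {true, false, false, true, false, true};
  std::vector<int> out;
  for (size_t i = 0; i < in.size(); ++i) {
    if (mask[i]) {
      out.push_back(in[i]);
    }
  }
  for (int v : out) {
    printf("%d ", v); // prints: 1 4 6
  }
  printf("\n");
  return 0;
}
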
Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
(new file; every line below is an addition)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& masked_select_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mask,
    Tensor& out) {
  ScalarType in_type = in.scalar_type();

  ET_KERNEL_CHECK(
      ctx,
      executorch::runtime::tensor_is_realhbbf16_type(in),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, mask.scalar_type() == ScalarType::Bool, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, out.scalar_type() == in_type, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_are_broadcastable_between(in, mask), InvalidArgument, out);

  // If input or mask is empty, the output should be empty
  if (in.numel() == 0 || mask.numel() == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, {0}) == Error::Ok, InvalidArgument, out);
    return out;
  }

  // Compute the shape resulting from broadcasting the mask against the input
  size_t broadcast_ndim = 0;
  Tensor::SizesType broadcast_sizes[kTensorDimensionLimit];
  Error err = get_broadcast_target_size(
      in, mask, broadcast_sizes, kTensorDimensionLimit, &broadcast_ndim);
  if (err != Error::Ok) {
    ET_KERNEL_CHECK_MSG(
        ctx, false, InvalidArgument, out, "Failed to broadcast input and mask");
  }
  size_t broadcast_numel = 1;
  for (size_t i = 0; i < broadcast_ndim; i++) {
    broadcast_numel *= broadcast_sizes[i];
  }

  // Compute the number of out elements
  size_t mask_true_count = 0;
  const bool* const mask_data = mask.const_data_ptr<bool>();
  for (size_t i = 0; i < mask.numel(); ++i) {
    if (mask_data[i]) {
      mask_true_count++;
    }
  }
  Tensor::SizesType out_numel =
      mask_true_count * (broadcast_numel / mask.numel());

  // Resize the out tensor
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, {out_numel}) == Error::Ok, InvalidArgument, out);

  const char* const in_data =
      reinterpret_cast<const char*>(in.const_data_ptr());
  char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
  const auto elem_size = in.element_size();

  // Figure out if `in` is broadcasted
  bool in_is_broadcasted = false;
  if (in.dim() != broadcast_ndim) {
    in_is_broadcasted = true;
  } else {
    for (size_t i = 0; i < in.dim(); ++i) {
      if (in.size(i) != broadcast_sizes[i]) {
        in_is_broadcasted = true;
      }
    }
  }

  // Figure out if `mask` is broadcasted
  bool mask_is_broadcasted = false;
  if (mask.dim() != broadcast_ndim) {
    mask_is_broadcasted = true;
  } else {
    for (size_t i = 0; i < mask.dim(); ++i) {
      if (mask.size(i) != broadcast_sizes[i]) {
        mask_is_broadcasted = true;
      }
    }
  }

  // Figure out if either `in` or `mask` is broadcasted
  bool any_is_broadcasted = (in_is_broadcasted || mask_is_broadcasted);

  size_t out_ix = 0;
  for (size_t i = 0; i < broadcast_numel; ++i) {
    size_t in_linear_index = i;
    size_t mask_linear_index = i;

    // If either `in` or `mask` is broadcasted, we need to compute the indexes
    // in the broadcasted space.
    if (any_is_broadcasted) {
      size_t broadcast_indexes[kTensorDimensionLimit];
      delinearize_index(
          i,
          {broadcast_sizes, broadcast_ndim},
          broadcast_indexes,
          kTensorDimensionLimit);

      if (in_is_broadcasted) {
        in_linear_index =
            linearize_access_indexes(broadcast_indexes, broadcast_ndim, in);
      }
      if (mask_is_broadcasted) {
        mask_linear_index =
            linearize_access_indexes(broadcast_indexes, broadcast_ndim, mask);
      }
    }

    // If the mask is true, copy the value from `in` to `out` and increment the
    // `out_ix`
    if (mask_data[mask_linear_index]) {
      memcpy(
          out_data + out_ix * elem_size,
          in_data + in_linear_index * elem_size,
          elem_size);
      out_ix++;
    }
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
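
A note on the resize step in the kernel above: the output length is derived from the mask alone, because every true mask element selects exactly one value per repetition of the mask across the broadcast shape. The sketch below is a small standalone illustration of that arithmetic in plain C++, independent of the ExecuTorch runtime; the values come from the BroadcastMask test further down and the variable names are illustrative only.

// Mirrors the BroadcastMask test: in is {2, 3}, mask is {3}.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t broadcast_numel = 2 * 3; // numel of the broadcast shape {2, 3}
  const size_t mask_numel = 3;          // numel of the mask {3}
  const size_t mask_true_count = 1;     // mask = {false, true, false}
  // Each true mask element selects one value per repetition of the mask
  // across the broadcast shape.
  const size_t out_numel = mask_true_count * (broadcast_numel / mask_numel);
  printf("out_numel = %zu\n", out_numel); // prints 2, matching tf.zeros({2})
  return 0;
}
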

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions

@@ -547,6 +547,11 @@
     - arg_meta: null
       kernel_name: torch::executor::masked_scatter_out

+- op: masked_select.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::masked_select_out
+
 - op: max.dim_max
   kernels:
     - arg_meta: null
Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
(new file; every line below is an addition)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::SupportedFeatures;
using torch::executor::testing::TensorFactory;

class OpMaskedSelectOutTest : public OperatorTest {
 protected:
  Tensor&
  op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) {
    return torch::executor::aten::masked_select_outf(context_, in, mask, out);
  }
};

TEST_F(OpMaskedSelectOutTest, SmokeTest) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
  Tensor out = tf.zeros({3});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 4, 6}));
}

TEST_F(OpMaskedSelectOutTest, BroadcastInput) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({3}, {1, 2, 3});
  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
  Tensor out = tf.zeros({3});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 1, 3}));
}

TEST_F(OpMaskedSelectOutTest, BroadcastMask) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor mask = tfBool.make({3}, {false, true, false});

  Tensor out = tf.zeros({2});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({2}, {2, 5}));
}

TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.ones({2, 3, 4, 1});
  Tensor mask = tfBool.ones({2, 1, 1, 5});
  Tensor out = tf.zeros({120});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.ones({120}));
}

TEST_F(OpMaskedSelectOutTest, EmptyInput) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 0}, {});
  Tensor mask = tfBool.make({2, 1}, {true, true});
  Tensor out = tf.zeros({0});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
}

TEST_F(OpMaskedSelectOutTest, EmptyMask) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 1}, {100, 200});
  Tensor mask = tfBool.make({2, 0}, {});
  Tensor out = tf.zeros({0});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
}

TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 0}, {});
  Tensor mask = tfBool.make({0}, {});
  Tensor out = tf.zeros({0});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -255,6 +255,7 @@ def define_common_targets():
     _common_op_test("op_lt_test", ["aten", "portable"])
     _common_op_test("op_masked_fill_test", ["aten", "portable"])
     _common_op_test("op_masked_scatter_test", ["aten", "portable"])
+    _common_op_test("op_masked_select_test", ["aten", "portable"])
     _common_op_test("op_max_test", ["aten", "portable"])
     _common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"])
     _common_op_test("op_maximum_test", ["aten", "portable"])

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 6 additions & 0 deletions

@@ -789,6 +789,12 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_masked_select",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+        ],
+    ),
     op_target(
         name = "op_max",
         deps = [
