
Commit 4f8ccab

larryliu0820 authored and facebook-github-bot committed
Cleanup operator class (#113)
Summary:
Pull Request resolved: #113

This work improves memory efficiency in operator/kernel registration. We deprecated `Operator` in favor of `Kernel` quite a while ago in the specialized kernels project. Since then, each `Operator` has held a fixed number (8) of kernels and the registry has held a fixed number (250) of `Operator`s. This wastes memory in two ways:

1. We register prim ops unconditionally as `Operator`s even though each has only 1 kernel, an 8x waste.
2. We should be able to leverage selective build information about how many kernels are registered for a given build; that number can be passed in via a preprocessor flag.

To address this, this diff removes the 2-layer Operator/Kernel structure and stores only kernels in the registry. The downside is a possible slight regression in the static-init-time `register_kernels()` call, which now does a linear search to make sure there are no duplicate kernels in the registry; this is O(N^2).

Reviewed By: JacobSzwejbka

Differential Revision: D48619001

fbshipit-source-id: 62c3e8e2938a93730c5345d903b450c0b7299a30
1 parent 2b85bcb commit 4f8ccab
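
To make the registry change concrete, here is a minimal, self-contained sketch of a flat kernel registry with the O(N^2) duplicate check described above. It is illustrative only and not the ExecuTorch implementation: the MAX_NUM_KERNELS flag, the simplified Kernel/OpFunction types, and the boolean error handling are assumptions made for this example.

// Illustrative sketch only -- not the actual ExecuTorch registry. Names such
// as MAX_NUM_KERNELS and the simplified Kernel/OpFunction types are assumptions.
#include <cstddef>
#include <cstring>

namespace sketch {

struct EValue;                                 // opaque placeholder value type
using OpFunction = void (*)(void*, EValue**);  // simplified kernel signature

struct Kernel {
  const char* name_;  // fully qualified op name, e.g. "aten::sym_size.int"
  OpFunction op_;     // function invoked when the op is executed
};

// Capacity can be set per build (e.g. by selective build through a
// preprocessor flag) instead of reserving 250 operators x 8 kernels each.
#ifndef MAX_NUM_KERNELS
#define MAX_NUM_KERNELS 256
#endif

static Kernel registered_kernels[MAX_NUM_KERNELS];
static size_t num_registered_kernels = 0;

// Called at static-init time. Linear scan per incoming kernel to reject
// duplicates, so registering N kernels is O(N^2) overall.
inline bool register_kernels(const Kernel* kernels, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < num_registered_kernels; ++j) {
      if (std::strcmp(registered_kernels[j].name_, kernels[i].name_) == 0) {
        return false;  // duplicate kernel name
      }
    }
    if (num_registered_kernels == MAX_NUM_KERNELS) {
      return false;  // registry is full
    }
    registered_kernels[num_registered_kernels++] = kernels[i];
  }
  return true;
}

}  // namespace sketch

Compared with the old layout, which reserved 250 Operators times 8 Kernels regardless of use, a flat array like this only pays for the kernels a build actually registers, at the cost of the quadratic duplicate check at startup.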

File tree

5 files changed: +127 -278 lines changed


extension/pybindings/module.cpp

Lines changed: 5 additions & 5 deletions
@@ -418,12 +418,12 @@ void create_profile_block(const std::string& name) {
 }
 
 py::list get_ops_names() {
-  const auto& ops_array = getOpsArray();
-  py::list list(ops_array.size());
-  for (size_t i = 0; i < ops_array.size(); ++i) {
-    list[i] = std::string(ops_array[i].name_);
+  const auto& kernels_array = get_kernels();
+  py::set set;
+  for (size_t i = 0; i < kernels_array.size(); ++i) {
+    set.add(kernels_array[i].name_);
   }
-  return list;
+  return py::list(set);
 }
 
 } // namespace

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 26 additions & 26 deletions
@@ -11,7 +11,7 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/kernel/operator_registry.h>
 
-using OpArrayRef = ::torch::executor::ArrayRef<::torch::executor::Operator>;
+using KernelArrayRef = ::torch::executor::ArrayRef<::torch::executor::Kernel>;
 using torch::executor::function::et_copy_index;
 
 namespace torch {
@@ -20,9 +20,9 @@ namespace function {
 
 namespace {
 
-static Operator prim_ops[] = {
+static Kernel prim_ops[] = {
     // aten::sym_size.int(Tensor self, int dim) -> SymInt
-    Operator(
+    Kernel(
         "aten::sym_size.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -35,7 +35,7 @@ static Operator prim_ops[] = {
           out = EValue(size);
         }),
     // aten::sym_numel(Tensor self) -> SymInt
-    Operator(
+    Kernel(
         "aten::sym_numel",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -46,7 +46,7 @@ static Operator prim_ops[] = {
           out = EValue(numel);
         }),
     // executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar
-    Operator(
+    Kernel(
         "executorch_prim::add.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -64,7 +64,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar
-    Operator(
+    Kernel(
         "executorch_prim::sub.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -82,7 +82,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar
-    Operator(
+    Kernel(
         "executorch_prim::mul.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -100,7 +100,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
-    Operator(
+    Kernel(
         "executorch_prim::floordiv.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -118,7 +118,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::eq.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -138,7 +138,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::gt.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -158,7 +158,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::lt.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -178,7 +178,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::ge.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -198,7 +198,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::le.Scalar(Scalar, Scalar) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::le.Scalar",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -220,7 +220,7 @@ static Operator prim_ops[] = {
     // TODO(T159977211): wait a little bit so older models with these ops are
     // regenerated and then delete them
    // executorch_prim::add.int(int, int) -> int
-    Operator(
+    Kernel(
         "executorch_prim::add.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -231,7 +231,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::sub.int(int, int) -> int
-    Operator(
+    Kernel(
         "executorch_prim::sub.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -242,7 +242,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::mul.int(int, int) -> int
-    Operator(
+    Kernel(
         "executorch_prim::mul.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -253,7 +253,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::floordiv.int(int, int) -> int
-    Operator(
+    Kernel(
         "executorch_prim::floordiv.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -264,7 +264,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::eq.int(int, int) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::eq.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -275,7 +275,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::gt.int(int, int) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::gt.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -286,7 +286,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::lt.int(int, int) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::lt.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -297,7 +297,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::ge.int(int, int) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::ge.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -308,7 +308,7 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::le.int(int, int) -> bool
-    Operator(
+    Kernel(
         "executorch_prim::le.int",
         [](RuntimeContext& context, EValue** stack) {
           (void)context;
@@ -319,17 +319,17 @@ static Operator prim_ops[] = {
         }),
 
     // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
-    Operator("executorch_prim::et_copy_index.tensor", &et_copy_index),
+    Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
 
 };
 
-static OpArrayRef op_array_ref(
+static KernelArrayRef kernel_array_ref(
     prim_ops,
-    prim_ops + sizeof(prim_ops) / sizeof(Operator));
+    prim_ops + sizeof(prim_ops) / sizeof(Kernel));
 
 // Return value not used. Keep the static variable assignment to register
 // operators in static initialization time.
-static auto success_with_op_reg = register_operators(op_array_ref);
+static auto success_with_kernel_reg = register_kernels(kernel_array_ref);
 
 } // namespace
 } // namespace function
