Skip to content

Commit 153cf3b

Browse files
larryliu0820 authored and facebook-github-bot committed
Allow kernel manual registration (#491)
Summary: X-link: pytorch/pytorch#110086 Pull Request resolved: #491 Exposing a codegen mode for generating a hook that lets users register their kernels. If we pass the `--manual-registration` flag to `gen_executorch.py`, we will generate the following files: 1. RegisterKernels.h, which declares a `register_all_kernels()` API inside the `torch::executor` namespace. 2. RegisterKernelsEverything.cpp, which implements `register_all_kernels()` by defining an array of generated kernels. This way users can depend on the library declared by the `executorch_generated_lib` macro (with `manual_registration=True`) and include `RegisterKernels.h`. Then they can manually call `register_all_kernels()` instead of relying on the C++ static initialization mechanism, which is not available in some embedded systems. Reviewed By: cccclai Differential Revision: D49439673 fbshipit-source-id: 50ee80e1040e5dd32eb94f08c78f1544811cd08d
1 parent 8cbfa80 commit 153cf3b

File tree

5 files changed

+150
-9
lines changed

5 files changed

+150
-9
lines changed

codegen/templates/RegisterKernels.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// ${generated_comment}
10+
// This implements register_all_kernels() API that is declared in
11+
// RegisterKernels.h
12+
#include "RegisterKernels.h"
13+
#include "${fn_header}" // Generated Function import headers
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
Error register_all_kernels() {
19+
Kernel kernels_to_register[] = {
20+
${unboxed_kernels} // Generated kernels
21+
};
22+
Error success_with_kernel_reg = register_kernels(kernels_to_register);
23+
if (success_with_kernel_reg != Error::Ok) {
24+
ET_LOG(Error, "Failed register all kernels");
25+
return success_with_kernel_reg;
26+
}
27+
return Error::Ok;
28+
}
29+
30+
} // namespace executor
31+
} // namespace torch

codegen/templates/RegisterKernels.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// ${generated_comment}
// Exposing an API for registering all kernels at once.
#pragma once

#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/profiler.h>

namespace torch {
namespace executor {

/**
 * Registers all generated kernels in one call.
 *
 * Meant to be invoked explicitly by the caller, as an alternative to relying
 * on C++ static initialization to register kernels.
 *
 * @return Error::Ok on success, or the error reported by the registration
 *     step on failure.
 */
Error register_all_kernels();

} // namespace executor
} // namespace torch

runtime/kernel/test/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,30 @@ def define_common_targets():
7575
],
7676
)
7777

78+
executorch_generated_lib(
79+
name = "test_manual_registration_lib",
80+
deps = [
81+
":executorch_all_ops",
82+
"//executorch/kernels/portable:operators",
83+
],
84+
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
85+
manual_registration = True,
86+
visibility = [
87+
"//executorch/...",
88+
],
89+
)
90+
91+
runtime.cxx_test(
92+
name = "test_kernel_manual_registration",
93+
srcs = [
94+
"test_kernel_manual_registration.cpp",
95+
],
96+
deps = [
97+
"//executorch/runtime/kernel:operator_registry",
98+
":test_manual_registration_lib",
99+
],
100+
)
101+
78102
for aten_mode in (True, False):
79103
aten_suffix = "_aten" if aten_mode else ""
80104

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Include RegisterKernels.h and call register_all_kernels().
10+
#include <gtest/gtest.h>
11+
12+
#include <executorch/runtime/kernel/operator_registry.h>
13+
#include <executorch/runtime/kernel/test/RegisterKernels.h>
14+
#include <executorch/runtime/platform/runtime.h>
15+
16+
using namespace ::testing;
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
class KernelManualRegistrationTest : public ::testing::Test {
22+
public:
23+
void SetUp() override {
24+
torch::executor::runtime_init();
25+
}
26+
};
27+
28+
TEST_F(KernelManualRegistrationTest, ManualRegister) {
29+
Error result = register_all_kernels();
30+
// Check that we can find the kernel for foo.
31+
EXPECT_EQ(result, Error::Ok);
32+
EXPECT_FALSE(hasOpsFn("fpp"));
33+
EXPECT_TRUE(hasOpsFn("aten::add.out"));
34+
}
35+
36+
} // namespace executor
37+
} // namespace torch

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ GENERATED_SOURCES = [
2020
"RegisterCodegenUnboxedKernelsEverything.cpp",
2121
]
2222

23+
MANUAL_REGISTRATION_SOURCES = [
24+
# buildifier: keep sorted
25+
"RegisterKernelsEverything.cpp",
26+
]
27+
28+
MANUAL_REGISTRATION_HEADERS = [
29+
"RegisterKernels.h",
30+
]
31+
2332
# Fake kernels only return `out` or any other tensor from arguments
2433
CUSTOM_OPS_DUMMY_KERNEL_SOURCES = ["Register{}Stub.cpp".format(backend) for backend in STATIC_DISPATCH_BACKENDS]
2534

@@ -35,11 +44,9 @@ CUSTOM_OPS_SCHEMA_REGISTRATION_SOURCES = [
3544
def et_operator_library(
3645
name,
3746
ops = [],
38-
exported_deps = [],
3947
model = None,
4048
include_all_operators = False,
4149
ops_schema_yaml_target = None,
42-
define_static_targets = False,
4350
**kwargs):
4451
genrule_cmd = [
4552
"$(exe //executorch/codegen/tools:gen_oplist)",
@@ -61,6 +68,10 @@ def et_operator_library(
6168
genrule_cmd.append(
6269
"--include_all_operators",
6370
)
71+
72+
# TODO(larryliu0820): Remove usages of this flag.
73+
if "define_static_targets" in kwargs:
74+
kwargs.pop("define_static_targets")
6475
runtime.genrule(
6576
name = name,
6677
macros_only = False,
@@ -70,18 +81,22 @@ def et_operator_library(
7081
**kwargs
7182
)
7283

73-
def _get_headers(genrule_name, prefix = "", custom_op = None):
84+
def _get_headers(genrule_name, prefix = "", custom_op = None, manual_registration = False):
85+
headers = OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
7486
return {
7587
prefix + f: ":{}[{}]".format(genrule_name, f)
76-
for f in OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
88+
for f in (MANUAL_REGISTRATION_HEADERS if manual_registration else [])
89+
}, {
90+
prefix + f: ":{}[{}]".format(genrule_name, f)
91+
for f in headers
7792
}
7893

7994
def _prepare_genrule_and_lib(
8095
name,
8196
functions_yaml_path = None,
8297
custom_ops_yaml_path = None,
83-
custom_ops_aten_kernel_deps = [],
8498
custom_ops_requires_runtime_registration = True,
99+
manual_registration = False,
85100
aten_mode = False):
86101
"""
87102
This function returns two dicts `genrules` and `libs`, derived from the arguments being passed
@@ -122,15 +137,18 @@ def _prepare_genrule_and_lib(
122137
# actually-generated files matches GENERATED_FILES.
123138
]
124139

140+
# Sources for generated kernel registration lib
141+
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES
142+
125143
# The command will always generate these files.
126-
genrule_outs = GENERATED_SOURCES + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else [])
144+
genrule_outs = sources + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else []) + MANUAL_REGISTRATION_HEADERS
127145

128146
genrules = {}
129147
libs = {}
130148

131149
# if aten_mode is true, we don't need functions_yaml_path
132150
genrule_name = name + "_combined"
133-
headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path)
151+
exported_headers, headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path, manual_registration = manual_registration)
134152

135153
# need to register ATen ops into Executorch runtime:
136154
need_reg_aten_ops = aten_mode or functions_yaml_path
@@ -149,6 +167,10 @@ def _prepare_genrule_and_lib(
149167
]
150168
if aten_mode:
151169
genrule_cmd = genrule_cmd + ["--use_aten_lib"]
170+
if manual_registration:
171+
genrule_cmd = genrule_cmd + [
172+
"--manual_registration",
173+
]
152174
if custom_ops_yaml_path:
153175
genrule_cmd = genrule_cmd + [
154176
"--custom_ops_yaml_path=" + custom_ops_yaml_path,
@@ -160,13 +182,15 @@ def _prepare_genrule_and_lib(
160182

161183
if need_reg_ops:
162184
libs[name] = {
185+
"exported_headers": exported_headers,
163186
"genrule": genrule_name,
164187
"headers": headers,
165-
"srcs": GENERATED_SOURCES,
188+
"srcs": sources,
166189
}
167190

168191
header_lib = name + "_headers"
169192
libs[header_lib] = {
193+
"exported_headers": exported_headers,
170194
"headers": headers,
171195
}
172196
return genrules, libs
@@ -303,6 +327,7 @@ def executorch_generated_lib(
303327
custom_ops_requires_runtime_registration = True,
304328
visibility = [],
305329
aten_mode = False,
330+
manual_registration = False,
306331
use_default_aten_ops_lib = True,
307332
deps = [],
308333
xplat_deps = [],
@@ -350,6 +375,7 @@ def executorch_generated_lib(
350375
visibility: Visibility of the C++ library targets.
351376
deps: Additional deps of the main C++ library. Needs to be in either `//executorch` or `//caffe2` module.
352377
platforms: platforms args to runtime.cxx_library (only used when in xplat)
378+
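manual_registration: if true, generate RegisterKernelsEverything.cpp and RegisterKernels.h, which expose `register_all_kernels()` so kernels can be registered with an explicit call.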
353379
use_default_aten_ops_lib: If `aten_mode` is True AND this flag is True, use `torch_mobile_all_ops` for ATen operator library.
354380
xplat_deps: Additional xplat deps, can be used to provide custom operator library.
355381
fbcode_deps: Additional fbcode deps, can be used to provide custom operator library.
@@ -391,9 +417,9 @@ def executorch_generated_lib(
391417
name = name,
392418
functions_yaml_path = functions_yaml_path,
393419
custom_ops_yaml_path = custom_ops_yaml_path,
394-
custom_ops_aten_kernel_deps = custom_ops_aten_kernel_deps,
395420
custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration,
396421
aten_mode = aten_mode,
422+
manual_registration = manual_registration,
397423
)
398424

399425
# genrule for selective build from static operator list
@@ -457,6 +483,7 @@ def executorch_generated_lib(
457483
# target, and are not meant to be used by targets outside of this
458484
# directory.
459485
headers = libs[lib_name]["headers"],
486+
exported_headers = libs[lib_name]["exported_headers"],
460487
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
461488
# link_whole is necessary because the operators register themselves via
462489
# static initializers that run at program startup.

0 commit comments

Comments
 (0)