Skip to content

Commit 11f18f2

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Allow kernel manual registration (#491)
Summary: X-link: pytorch/pytorch#110086 Exposing a codegen mode for generating a hook for user to register their kernels. If we pass `--manual-registration` flag to `gen_executorch.py`, we will generate the following files: 1. RegisterKernels.h which declares a `register_all_kernels()` API inside `torch::executor` namespace. 2. RegisterKernelsEverything.cpp which implements `register_all_kernels()` by defining an array of generated kernels. This way user can depend on the library declared by `executorch_generated_lib` macro (with `manual_registration=True`) and be able to include `RegisterKernels.h`. Then they can manually call `register_all_kernels()` instead of relying on C++ static initialization mechanism which is not available in some embedded systems. Reviewed By: cccclai Differential Revision: D49439673
1 parent b6e182c commit 11f18f2

File tree

5 files changed

+146
-8
lines changed

5 files changed

+146
-8
lines changed

codegen/templates/RegisterKernels.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// ${generated_comment}
10+
// This implements register_all_kernels() API that is declared in
11+
// RegisterKernels.h
12+
#include "RegisterKernels.h"
13+
#include "${fn_header}" // Generated Function import headers
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
Error register_all_kernels() {
19+
Kernel kernels_to_register[] = {
20+
${unboxed_kernels} // Generated kernels
21+
};
22+
Error success_with_kernel_reg = register_kernels(kernels_to_register);
23+
if (success_with_kernel_reg != Error::Ok) {
24+
ET_LOG(Error, "Failed register all kernels");
25+
return success_with_kernel_reg;
26+
}
27+
return Error::Ok;
28+
}
29+
30+
} // namespace executor
31+
} // namespace torch

codegen/templates/RegisterKernels.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// ${generated_comment}
10+
// Exposing an API for registering all kernels at once.
11+
#include <executorch/runtime/core/evalue.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/kernel/operator_registry.h>
14+
#include <executorch/runtime/platform/profiler.h>
15+
16+
namespace torch {
17+
namespace executor {
18+
19+
Error register_all_kernels();
20+
21+
} // namespace executor
22+
} // namespace torch

runtime/kernel/test/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,30 @@ def define_common_targets():
7575
],
7676
)
7777

78+
executorch_generated_lib(
79+
name = "test_manual_registration_lib",
80+
deps = [
81+
":executorch_all_ops",
82+
"//executorch/kernels/portable:operators",
83+
],
84+
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
85+
manual_registration = True,
86+
visibility = [
87+
"//executorch/...",
88+
],
89+
)
90+
91+
runtime.cxx_test(
92+
name = "test_kernel_manual_registration",
93+
srcs = [
94+
"test_kernel_manual_registration.cpp",
95+
],
96+
deps = [
97+
"//executorch/runtime/kernel:operator_registry",
98+
":test_manual_registration_lib",
99+
],
100+
)
101+
78102
for aten_mode in (True, False):
79103
aten_suffix = "_aten" if aten_mode else ""
80104

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Include RegisterKernels.h and call register_all_kernels().
10+
#include <gtest/gtest.h>
11+
12+
#include <executorch/runtime/kernel/operator_registry.h>
13+
#include <executorch/runtime/kernel/test/RegisterKernels.h>
14+
#include <executorch/runtime/platform/runtime.h>
15+
16+
using namespace ::testing;
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
class KernelManualRegistrationTest : public ::testing::Test {
22+
public:
23+
void SetUp() override {
24+
torch::executor::runtime_init();
25+
}
26+
};
27+
28+
TEST_F(KernelManualRegistrationTest, ManualRegister) {
29+
Error result = register_all_kernels();
30+
// Check that we can find the kernel for foo.
31+
EXPECT_EQ(result, Error::Ok);
32+
EXPECT_FALSE(hasOpsFn("fpp"));
33+
EXPECT_TRUE(hasOpsFn("aten::add.out"));
34+
}
35+
36+
} // namespace executor
37+
} // namespace torch

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ GENERATED_SOURCES = [
2020
"RegisterCodegenUnboxedKernelsEverything.cpp",
2121
]
2222

23+
MANUAL_REGISTRATION_SOURCES = [
24+
# buildifier: keep sorted
25+
"RegisterKernelsEverything.cpp",
26+
]
27+
28+
MANUAL_REGISTRATION_HEADERS = [
29+
"RegisterKernels.h",
30+
]
31+
2332
# Fake kernels only return `out` or any other tensor from arguments
2433
CUSTOM_OPS_DUMMY_KERNEL_SOURCES = ["Register{}Stub.cpp".format(backend) for backend in STATIC_DISPATCH_BACKENDS]
2534

@@ -35,7 +44,6 @@ CUSTOM_OPS_SCHEMA_REGISTRATION_SOURCES = [
3544
def et_operator_library(
3645
name,
3746
ops = [],
38-
exported_deps = [],
3947
model = None,
4048
include_all_operators = False,
4149
ops_schema_yaml_target = None,
@@ -70,18 +78,22 @@ def et_operator_library(
7078
**kwargs
7179
)
7280

73-
def _get_headers(genrule_name, prefix = "", custom_op = None):
81+
def _get_headers(genrule_name, prefix = "", custom_op = None, manual_registration = False):
82+
headers = OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
7483
return {
7584
prefix + f: ":{}[{}]".format(genrule_name, f)
76-
for f in OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
85+
for f in (MANUAL_REGISTRATION_HEADERS if manual_registration else [])
86+
}, {
87+
prefix + f: ":{}[{}]".format(genrule_name, f)
88+
for f in headers
7789
}
7890

7991
def _prepare_genrule_and_lib(
8092
name,
8193
functions_yaml_path = None,
8294
custom_ops_yaml_path = None,
83-
custom_ops_aten_kernel_deps = [],
8495
custom_ops_requires_runtime_registration = True,
96+
manual_registration = False,
8597
aten_mode = False):
8698
"""
8799
This function returns two dicts `genrules` and `libs`, derived from the arguments being passed
@@ -122,15 +134,18 @@ def _prepare_genrule_and_lib(
122134
# actually-generated files matches GENERATED_FILES.
123135
]
124136

137+
# Sources for generated kernel registration lib
138+
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES
139+
125140
# The command will always generate these files.
126-
genrule_outs = GENERATED_SOURCES + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else [])
141+
genrule_outs = sources + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else []) + MANUAL_REGISTRATION_HEADERS
127142

128143
genrules = {}
129144
libs = {}
130145

131146
# if aten_mode is true, we don't need functions_yaml_path
132147
genrule_name = name + "_combined"
133-
headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path)
148+
exported_headers, headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path, manual_registration = manual_registration)
134149

135150
# need to register ATen ops into Executorch runtime:
136151
need_reg_aten_ops = aten_mode or functions_yaml_path
@@ -149,6 +164,10 @@ def _prepare_genrule_and_lib(
149164
]
150165
if aten_mode:
151166
genrule_cmd = genrule_cmd + ["--use_aten_lib"]
167+
if manual_registration:
168+
genrule_cmd = genrule_cmd + [
169+
"--manual_registration",
170+
]
152171
if custom_ops_yaml_path:
153172
genrule_cmd = genrule_cmd + [
154173
"--custom_ops_yaml_path=" + custom_ops_yaml_path,
@@ -160,13 +179,15 @@ def _prepare_genrule_and_lib(
160179

161180
if need_reg_ops:
162181
libs[name] = {
182+
"exported_headers": exported_headers,
163183
"genrule": genrule_name,
164184
"headers": headers,
165-
"srcs": GENERATED_SOURCES,
185+
"srcs": sources,
166186
}
167187

168188
header_lib = name + "_headers"
169189
libs[header_lib] = {
190+
"exported_headers": exported_headers,
170191
"headers": headers,
171192
}
172193
return genrules, libs
@@ -303,6 +324,7 @@ def executorch_generated_lib(
303324
custom_ops_requires_runtime_registration = True,
304325
visibility = [],
305326
aten_mode = False,
327+
manual_registration = False,
306328
use_default_aten_ops_lib = True,
307329
deps = [],
308330
xplat_deps = [],
@@ -350,6 +372,7 @@ def executorch_generated_lib(
350372
visibility: Visibility of the C++ library targets.
351373
deps: Additinal deps of the main C++ library. Needs to be in either `//executorch` or `//caffe2` module.
352374
platforms: platforms args to runtime.cxx_library (only used when in xplat)
375+
manual_registration: if true, generate RegisterKernels.cpp and RegisterKernels.h.
353376
use_default_aten_ops_lib: If `aten_mode` is True AND this flag is True, use `torch_mobile_all_ops` for ATen operator library.
354377
xplat_deps: Additional xplat deps, can be used to provide custom operator library.
355378
fbcode_deps: Additional fbcode deps, can be used to provide custom operator library.
@@ -391,9 +414,9 @@ def executorch_generated_lib(
391414
name = name,
392415
functions_yaml_path = functions_yaml_path,
393416
custom_ops_yaml_path = custom_ops_yaml_path,
394-
custom_ops_aten_kernel_deps = custom_ops_aten_kernel_deps,
395417
custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration,
396418
aten_mode = aten_mode,
419+
manual_registration = manual_registration,
397420
)
398421

399422
# genrule for selective build from static operator list
@@ -457,6 +480,7 @@ def executorch_generated_lib(
457480
# target, and are not meant to be used by targets outside of this
458481
# directory.
459482
headers = libs[lib_name]["headers"],
483+
exported_headers = libs[lib_name]["exported_headers"],
460484
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
461485
# link_whole is necessary because the operators register themselves via
462486
# static initializers that run at program startup.

0 commit comments

Comments
 (0)