Skip to content

Commit 8a1f1c2

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Allow kernel manual registration (#491)
Summary: X-link: pytorch/pytorch#110086 Exposing a codegen mode for generating a hook for user to register their kernels. If we pass `--manual-registration` flag to `gen_executorch.py`, we will generate the following files: 1. RegisterKernels.h which declares a `register_all_kernels()` API inside `torch::executor` namespace. 2. RegisterKernelsEverything.cpp which implements `register_all_kernels()` by defining an array of generated kernels. This way user can depend on the library declared by `executorch_generated_lib` macro (with `manual_registration=True`) and be able to include `RegisterKernels.h`. Then they can manually call `register_all_kernels()` instead of relying on C++ static initialization mechanism which is not available in some embedded systems. Reviewed By: cccclai Differential Revision: D49439673
1 parent f6a8d9d commit 8a1f1c2

File tree

5 files changed

+148
-5
lines changed

5 files changed

+148
-5
lines changed

codegen/templates/RegisterKernels.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
// ${generated_comment}
9+
// This implements register_all_kernels() API that is declared in
10+
// RegisterKernels.h
11+
#include "RegisterKernels.h"
12+
#include "${fn_header}" // Generated Function import headers
13+
14+
namespace torch {
15+
namespace executor {
16+
17+
Error register_all_kernels() {
18+
Kernel kernels_to_register[] = {
19+
${unboxed_kernels} // Generated kernels
20+
};
21+
Error success_with_kernel_reg = register_kernels(kernels_to_register);
22+
if (success_with_kernel_reg != Error::Ok) {
23+
ET_LOG(Error, "Failed register all kernels");
24+
return success_with_kernel_reg;
25+
}
26+
return Error::Ok;
27+
}
28+
29+
} // namespace executor
30+
} // namespace torch

codegen/templates/RegisterKernels.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
// ${generated_comment}
9+
// Exposing an API for registering all kernels at once.
10+
#include <executorch/runtime/core/evalue.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/kernel/operator_registry.h>
13+
#include <executorch/runtime/platform/profiler.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
Error register_all_kernels();
19+
20+
} // namespace executor
21+
} // namespace torch

runtime/kernel/test/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,30 @@ def define_common_targets():
7575
],
7676
)
7777

78+
executorch_generated_lib(
79+
name = "test_manual_registration_lib",
80+
deps = [
81+
":executorch_all_ops",
82+
"//executorch/kernels/portable:operators",
83+
],
84+
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
85+
manual_registration = True,
86+
visibility = [
87+
"//executorch/...",
88+
],
89+
)
90+
91+
runtime.cxx_test(
92+
name = "test_kernel_manual_registration",
93+
srcs = [
94+
"test_kernel_manual_registration.cpp",
95+
],
96+
deps = [
97+
"//executorch/runtime/kernel:operator_registry",
98+
":test_manual_registration_lib",
99+
],
100+
)
101+
78102
for aten_mode in (True, False):
79103
aten_suffix = "_aten" if aten_mode else ""
80104

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Include RegisterKernels.h and call register_all_kernels().
10+
#include <gtest/gtest.h>
11+
#include <vector>
12+
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/kernel/kernel_runtime_context.h>
15+
#include <executorch/runtime/kernel/operator_registry.h>
16+
#include <executorch/runtime/kernel/test/RegisterKernels.h>
17+
#include <executorch/runtime/platform/runtime.h>
18+
#include <executorch/test/utils/DeathTest.h>
19+
20+
using namespace ::testing;
21+
22+
namespace torch {
23+
namespace executor {
24+
25+
class KernelManualRegistrationTest : public ::testing::Test {
26+
public:
27+
void SetUp() override {
28+
torch::executor::runtime_init();
29+
}
30+
};
31+
32+
TEST_F(KernelManualRegistrationTest, ManualRegister) {
33+
Error result = register_all_kernels();
34+
// Check that we can find the kernel for foo.
35+
EXPECT_EQ(result, Error::Ok);
36+
EXPECT_FALSE(hasOpsFn("fpp"));
37+
EXPECT_TRUE(hasOpsFn("aten::add.out"));
38+
}
39+
40+
} // namespace executor
41+
} // namespace torch

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ GENERATED_SOURCES = [
2020
"RegisterCodegenUnboxedKernelsEverything.cpp",
2121
]
2222

23+
MANUAL_REGISTRATION_SOURCES = [
24+
# buildifier: keep sorted
25+
"RegisterKernelsEverything.cpp",
26+
]
27+
28+
MANUAL_REGISTRATION_HEADERS = [
29+
"RegisterKernels.h",
30+
]
31+
2332
# Fake kernels only return `out` or any other tensor from arguments
2433
CUSTOM_OPS_DUMMY_KERNEL_SOURCES = ["Register{}Stub.cpp".format(backend) for backend in STATIC_DISPATCH_BACKENDS]
2534

@@ -70,10 +79,14 @@ def et_operator_library(
7079
**kwargs
7180
)
7281

73-
def _get_headers(genrule_name, prefix = "", custom_op = None):
82+
def _get_headers(genrule_name, prefix = "", custom_op = None, manual_registration = False):
83+
headers = OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
7484
return {
7585
prefix + f: ":{}[{}]".format(genrule_name, f)
76-
for f in OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
86+
for f in (MANUAL_REGISTRATION_HEADERS if manual_registration else [])
87+
}, {
88+
prefix + f: ":{}[{}]".format(genrule_name, f)
89+
for f in headers
7790
}
7891

7992
def _prepare_genrule_and_lib(
@@ -82,6 +95,7 @@ def _prepare_genrule_and_lib(
8295
custom_ops_yaml_path = None,
8396
custom_ops_aten_kernel_deps = [],
8497
custom_ops_requires_runtime_registration = True,
98+
manual_registration = False,
8599
aten_mode = False):
86100
"""
87101
This function returns two dicts `genrules` and `libs`, derived from the arguments being passed
@@ -122,15 +136,18 @@ def _prepare_genrule_and_lib(
122136
# actually-generated files matches GENERATED_FILES.
123137
]
124138

139+
# Sources for generated kernel registration lib
140+
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES
141+
125142
# The command will always generate these files.
126-
genrule_outs = GENERATED_SOURCES + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else [])
143+
genrule_outs = sources + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else []) + MANUAL_REGISTRATION_HEADERS
127144

128145
genrules = {}
129146
libs = {}
130147

131148
# if aten_mode is true, we don't need functions_yaml_path
132149
genrule_name = name + "_combined"
133-
headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path)
150+
exported_headers, headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path, manual_registration = manual_registration)
134151

135152
# need to register ATen ops into Executorch runtime:
136153
need_reg_aten_ops = aten_mode or functions_yaml_path
@@ -149,6 +166,10 @@ def _prepare_genrule_and_lib(
149166
]
150167
if aten_mode:
151168
genrule_cmd = genrule_cmd + ["--use_aten_lib"]
169+
if manual_registration:
170+
genrule_cmd = genrule_cmd + [
171+
"--manual_registration",
172+
]
152173
if custom_ops_yaml_path:
153174
genrule_cmd = genrule_cmd + [
154175
"--custom_ops_yaml_path=" + custom_ops_yaml_path,
@@ -160,13 +181,15 @@ def _prepare_genrule_and_lib(
160181

161182
if need_reg_ops:
162183
libs[name] = {
184+
"exported_headers": exported_headers,
163185
"genrule": genrule_name,
164186
"headers": headers,
165-
"srcs": GENERATED_SOURCES,
187+
"srcs": sources,
166188
}
167189

168190
header_lib = name + "_headers"
169191
libs[header_lib] = {
192+
"exported_headers": exported_headers,
170193
"headers": headers,
171194
}
172195
return genrules, libs
@@ -303,6 +326,7 @@ def executorch_generated_lib(
303326
custom_ops_requires_runtime_registration = True,
304327
visibility = [],
305328
aten_mode = False,
329+
manual_registration = False,
306330
use_default_aten_ops_lib = True,
307331
deps = [],
308332
xplat_deps = [],
@@ -350,6 +374,7 @@ def executorch_generated_lib(
350374
visibility: Visibility of the C++ library targets.
351375
deps: Additinal deps of the main C++ library. Needs to be in either `//executorch` or `//caffe2` module.
352376
platforms: platforms args to runtime.cxx_library (only used when in xplat)
377+
manual_registration: if true, generate RegisterKernels.cpp and RegisterKernels.h.
353378
use_default_aten_ops_lib: If `aten_mode` is True AND this flag is True, use `torch_mobile_all_ops` for ATen operator library.
354379
xplat_deps: Additional xplat deps, can be used to provide custom operator library.
355380
fbcode_deps: Additional fbcode deps, can be used to provide custom operator library.
@@ -394,6 +419,7 @@ def executorch_generated_lib(
394419
custom_ops_aten_kernel_deps = custom_ops_aten_kernel_deps,
395420
custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration,
396421
aten_mode = aten_mode,
422+
manual_registration = manual_registration,
397423
)
398424

399425
# genrule for selective build from static operator list
@@ -457,6 +483,7 @@ def executorch_generated_lib(
457483
# target, and are not meant to be used by targets outside of this
458484
# directory.
459485
headers = libs[lib_name]["headers"],
486+
exported_headers = libs[lib_name]["exported_headers"],
460487
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
461488
# link_whole is necessary because the operators register themselves via
462489
# static initializers that run at program startup.

0 commit comments

Comments
 (0)