Skip to content

Allow kernel manual registration #491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions codegen/templates/RegisterKernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// ${generated_comment}
// This implements register_all_kernels() API that is declared in
// RegisterKernels.h
#include "RegisterKernels.h"
#include "${fn_header}" // Generated Function import headers

namespace torch {
namespace executor {

/// Registers every generated kernel with the ExecuTorch runtime in one call.
///
/// The kernel table is filled in by codegen via the ${unboxed_kernels}
/// template substitution.
///
/// @returns Error::Ok on success; otherwise forwards the error reported by
///          register_kernels() (a failure is also logged).
Error register_all_kernels() {
  Kernel kernels_to_register[] = {
      ${unboxed_kernels} // Generated kernels
  };
  Error success_with_kernel_reg = register_kernels(kernels_to_register);
  if (success_with_kernel_reg != Error::Ok) {
    // Fixed message grammar: was "Failed register all kernels".
    ET_LOG(Error, "Failed to register all kernels");
    return success_with_kernel_reg;
  }
  return Error::Ok;
}

} // namespace executor
} // namespace torch
22 changes: 22 additions & 0 deletions codegen/templates/RegisterKernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// ${generated_comment}
// Exposing an API for registering all kernels at once.
//
// NOTE: this header was missing an include guard; without one, a translation
// unit that includes it twice (directly and transitively) fails to compile.
#pragma once

#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/profiler.h>

namespace torch {
namespace executor {

/// Registers all generated kernels with the runtime.
/// @returns Error::Ok on success, or the registration error otherwise.
Error register_all_kernels();

} // namespace executor
} // namespace torch
24 changes: 24 additions & 0 deletions runtime/kernel/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,30 @@ def define_common_targets():
],
)

    # Generated operator library built with manual_registration = True, so the
    # kernels are exposed through a generated register_all_kernels() API
    # instead of being registered by static initializers at program startup.
    executorch_generated_lib(
        name = "test_manual_registration_lib",
        deps = [
            ":executorch_all_ops",
            "//executorch/kernels/portable:operators",
        ],
        functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
        manual_registration = True,
        visibility = [
            "//executorch/...",
        ],
    )

    # Unit test that calls register_all_kernels() from the manually-registered
    # library above and verifies the expected operators become visible.
    runtime.cxx_test(
        name = "test_kernel_manual_registration",
        srcs = [
            "test_kernel_manual_registration.cpp",
        ],
        deps = [
            "//executorch/runtime/kernel:operator_registry",
            ":test_manual_registration_lib",
        ],
    )

for aten_mode in (True, False):
aten_suffix = "_aten" if aten_mode else ""

Expand Down
37 changes: 37 additions & 0 deletions runtime/kernel/test/test_kernel_manual_registration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Include RegisterKernels.h and call register_all_kernels().
#include <gtest/gtest.h>

#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/kernel/test/RegisterKernels.h>
#include <executorch/runtime/platform/runtime.h>

using namespace ::testing;

namespace torch {
namespace executor {

// Test fixture: initializes the ExecuTorch runtime before each test case.
class KernelManualRegistrationTest : public ::testing::Test {
 public:
  void SetUp() override {
    // Required once per process before using any runtime facilities.
    torch::executor::runtime_init();
  }
};

// Verifies that register_all_kernels() succeeds, that a nonexistent operator
// is not found, and that a known generated kernel (aten::add.out) is found.
TEST_F(KernelManualRegistrationTest, ManualRegister) {
  Error result = register_all_kernels();
  // Manual registration itself must succeed.
  EXPECT_EQ(result, Error::Ok);
  // An operator name that was never registered must not be found.
  // (Was "fpp" — a typo for "foo"; the original comment referred to "foo".)
  EXPECT_FALSE(hasOpsFn("foo"));
  // The portable add.out kernel is part of the generated library.
  EXPECT_TRUE(hasOpsFn("aten::add.out"));
}

} // namespace executor
} // namespace torch
45 changes: 36 additions & 9 deletions shim/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ GENERATED_SOURCES = [
"RegisterCodegenUnboxedKernelsEverything.cpp",
]

# Sources generated when manual_registration is enabled; these implement
# register_all_kernels() instead of static-initializer registration.
MANUAL_REGISTRATION_SOURCES = [
    # buildifier: keep sorted
    "RegisterKernelsEverything.cpp",
]

# Headers generated alongside the manual-registration sources; they declare
# the register_all_kernels() API for consumers of the generated library.
MANUAL_REGISTRATION_HEADERS = [
    "RegisterKernels.h",
]

# Fake kernels only return `out` or any other tensor from arguments
CUSTOM_OPS_DUMMY_KERNEL_SOURCES = ["Register{}Stub.cpp".format(backend) for backend in STATIC_DISPATCH_BACKENDS]

Expand All @@ -35,11 +44,9 @@ CUSTOM_OPS_SCHEMA_REGISTRATION_SOURCES = [
def et_operator_library(
name,
ops = [],
exported_deps = [],
model = None,
include_all_operators = False,
ops_schema_yaml_target = None,
define_static_targets = False,
**kwargs):
genrule_cmd = [
"$(exe //executorch/codegen/tools:gen_oplist)",
Expand All @@ -61,6 +68,10 @@ def et_operator_library(
genrule_cmd.append(
"--include_all_operators",
)

# TODO(larryliu0820): Remove usages of this flag.
if "define_static_targets" in kwargs:
kwargs.pop("define_static_targets")
runtime.genrule(
name = name,
macros_only = False,
Expand All @@ -70,18 +81,22 @@ def et_operator_library(
**kwargs
)

def _get_headers(genrule_name, prefix = "", custom_op = None):
def _get_headers(genrule_name, prefix = "", custom_op = None, manual_registration = False):
    """Maps generated header filenames to their genrule output labels.

    Returns a pair of dicts keyed by `prefix + filename`:
      1. exported headers — the manual-registration headers (empty unless
         manual_registration is True), meant to be visible to dependents;
      2. private headers — the operator headers, plus the custom-op native
         function header when a custom-op yaml is in use.
    """
    headers = OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_op else [])
    return {
        prefix + f: ":{}[{}]".format(genrule_name, f)
        for f in (MANUAL_REGISTRATION_HEADERS if manual_registration else [])
    }, {
        prefix + f: ":{}[{}]".format(genrule_name, f)
        for f in headers
    }

def _prepare_genrule_and_lib(
name,
functions_yaml_path = None,
custom_ops_yaml_path = None,
custom_ops_aten_kernel_deps = [],
custom_ops_requires_runtime_registration = True,
manual_registration = False,
aten_mode = False):
"""
This function returns two dicts `genrules` and `libs`, derived from the arguments being passed
Expand Down Expand Up @@ -122,15 +137,18 @@ def _prepare_genrule_and_lib(
# actually-generated files matches GENERATED_FILES.
]

# Sources for generated kernel registration lib
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES

# The command will always generate these files.
genrule_outs = GENERATED_SOURCES + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else [])
genrule_outs = sources + OPERATOR_HEADERS + (CUSTOM_OPS_NATIVE_FUNCTION_HEADER if custom_ops_yaml_path else []) + MANUAL_REGISTRATION_HEADERS

genrules = {}
libs = {}

# if aten_mode is true, we don't need functions_yaml_path
genrule_name = name + "_combined"
headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path)
exported_headers, headers = _get_headers(genrule_name = genrule_name, custom_op = custom_ops_yaml_path, manual_registration = manual_registration)

# need to register ATen ops into Executorch runtime:
need_reg_aten_ops = aten_mode or functions_yaml_path
Expand All @@ -149,6 +167,10 @@ def _prepare_genrule_and_lib(
]
if aten_mode:
genrule_cmd = genrule_cmd + ["--use_aten_lib"]
if manual_registration:
genrule_cmd = genrule_cmd + [
"--manual_registration",
]
if custom_ops_yaml_path:
genrule_cmd = genrule_cmd + [
"--custom_ops_yaml_path=" + custom_ops_yaml_path,
Expand All @@ -160,13 +182,15 @@ def _prepare_genrule_and_lib(

if need_reg_ops:
libs[name] = {
"exported_headers": exported_headers,
"genrule": genrule_name,
"headers": headers,
"srcs": GENERATED_SOURCES,
"srcs": sources,
}

header_lib = name + "_headers"
libs[header_lib] = {
"exported_headers": exported_headers,
"headers": headers,
}
return genrules, libs
Expand Down Expand Up @@ -303,6 +327,7 @@ def executorch_generated_lib(
custom_ops_requires_runtime_registration = True,
visibility = [],
aten_mode = False,
manual_registration = False,
use_default_aten_ops_lib = True,
deps = [],
xplat_deps = [],
Expand Down Expand Up @@ -350,6 +375,7 @@ def executorch_generated_lib(
visibility: Visibility of the C++ library targets.
deps: Additional deps of the main C++ library. Needs to be in either `//executorch` or `//caffe2` module.
platforms: platforms args to runtime.cxx_library (only used when in xplat)
manual_registration: if true, generate RegisterKernels.cpp and RegisterKernels.h.
use_default_aten_ops_lib: If `aten_mode` is True AND this flag is True, use `torch_mobile_all_ops` for ATen operator library.
xplat_deps: Additional xplat deps, can be used to provide custom operator library.
fbcode_deps: Additional fbcode deps, can be used to provide custom operator library.
Expand Down Expand Up @@ -391,9 +417,9 @@ def executorch_generated_lib(
name = name,
functions_yaml_path = functions_yaml_path,
custom_ops_yaml_path = custom_ops_yaml_path,
custom_ops_aten_kernel_deps = custom_ops_aten_kernel_deps,
custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration,
aten_mode = aten_mode,
manual_registration = manual_registration,
)

# genrule for selective build from static operator list
Expand Down Expand Up @@ -457,6 +483,7 @@ def executorch_generated_lib(
# target, and are not meant to be used by targets outside of this
# directory.
headers = libs[lib_name]["headers"],
exported_headers = libs[lib_name]["exported_headers"],
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
# link_whole is necessary because the operators register themselves via
# static initializers that run at program startup.
Expand Down