
Commit 5462df7

larryliu0820 authored and facebook-github-bot committed
Add CMake build example for custom ops
Summary: This new example demonstrates how to register custom ops using PyTorch C++ APIs and how to build a library that links into both the ahead-of-time (AOT) flow and the runtime.

Reviewed By: digantdesai

Differential Revision: D48184410

fbshipit-source-id: 661087a5183b9cfff8ebbc541ea032cdb0f80b06
1 parent 217ddba commit 5462df7

9 files changed: 228 additions, 34 deletions

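As the summary notes, the example wires one custom operator into two places: an AOT shared library (custom_ops_aot_lib) that registers my_ops::mul4 with PyTorch so the model can be exported, and a runtime library (custom_ops_lib) that registers the out-variant kernel with the ExecuTorch runtime. Below is a minimal sketch of the AOT side, assuming the library was already built with -DREGISTER_EXAMPLE_CUSTOM_OP_2=ON into the default cmake-out location used by test_custom_ops.sh:

import torch

# Loading the shared library runs the TORCH_LIBRARY_FRAGMENT / TORCH_LIBRARY_IMPL
# registrations in custom_ops_2.cpp, which makes my_ops::mul4 visible to PyTorch.
torch.ops.load_library("cmake-out/examples/custom_ops/libcustom_ops_aot_lib.so")

x = torch.randn(2, 3)
y = torch.ops.my_ops.mul4.default(x)  # dispatches to mul4_impl
assert torch.allclose(y, x * 4)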

CMakeLists.txt

Lines changed: 16 additions & 4 deletions

@@ -43,9 +43,16 @@
 cmake_minimum_required(VERSION 3.13)
 project(executorch)

-# option to register custom ops in `examples/custom_ops`
-option(REGISTER_EXAMPLE_CUSTOM_OPS
-       "Register custom ops defined in examples/custom_ops" OFF)
+# option to register custom operator `my_ops::mul3` in
+# `examples/custom_ops/custom_ops_1.py`
+option(REGISTER_EXAMPLE_CUSTOM_OP_1
+       "Register custom operator defined in examples/custom_ops/custom_ops_1.py"
+       OFF)
+# option to register custom operator `my_ops::mul4` in
+# `examples/custom_ops/custom_ops_2.py`
+option(REGISTER_EXAMPLE_CUSTOM_OP_2
+       "Register custom operator defined in examples/custom_ops/custom_ops_2.py"
+       OFF)

 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 if(NOT CMAKE_CXX_STANDARD)
@@ -58,6 +65,11 @@ endif()
 # TODO(dbort): Fix these warnings and remove this flag.
 set(_common_compile_options -Wno-deprecated-declarations)

+if(REGISTER_EXAMPLE_CUSTOM_OP_2)
+  # Need to be linked to a shared library
+  list(APPEND _common_compile_options -fPIC)
+endif()
+
 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/..)

@@ -274,7 +286,7 @@ target_link_libraries(executor_runner executorch portable_kernels_bindings
 target_compile_options(executor_runner PUBLIC ${_common_compile_options})

 # Generate custom_ops_lib based on REGISTER_EXAMPLE_CUSTOM_OPS
-if(REGISTER_EXAMPLE_CUSTOM_OPS)
+if(REGISTER_EXAMPLE_CUSTOM_OP_1 OR REGISTER_EXAMPLE_CUSTOM_OP_2)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/custom_ops)
   target_link_libraries(executor_runner custom_ops_lib)
 endif()

examples/custom_ops/CMakeLists.txt

Lines changed: 51 additions & 19 deletions

@@ -37,10 +37,17 @@ file(GLOB_RECURSE _codegen_tools_srcs "${EXECUTORCH_ROOT}/codegen/tools/*.py")
 file(GLOB_RECURSE _codegen_templates "${EXECUTORCH_ROOT}/codegen/templates/*")
 file(GLOB_RECURSE _torchgen_srcs "${TORCH_ROOT}/torchgen/*.py")

-set(_gen_oplist_command
-    "${PYTHON_EXECUTABLE}" -m codegen.tools.gen_oplist
-    --output_path=${_oplist_yaml}
-    --ops_schema_yaml_path=${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml)
+# Selective build. If we want to register all ops in custom_ops.yaml, do
+# `--ops_schema_yaml_path=${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml)` instead of
+# `root_ops`
+set(_gen_oplist_command "${PYTHON_EXECUTABLE}" -m codegen.tools.gen_oplist
+    --output_path=${_oplist_yaml})
+
+if(REGISTER_EXAMPLE_CUSTOM_OP_2)
+  list(APPEND _gen_oplist_command --root_ops="my_ops::mul4.out")
+elseif(REGISTER_EXAMPLE_CUSTOM_OP_1)
+  list(APPEND _gen_oplist_command --root_ops="my_ops::mul3.out")
+endif()

 # Command to codegen C++ wrappers to register custom ops to both PyTorch and
 # Executorch runtime.
@@ -78,31 +85,56 @@ add_custom_command(
   WORKING_DIRECTORY ${EXECUTORCH_ROOT})
 # Prepare for C++ libraries.

-# 1. TODO: C++ library to register custom ops into PyTorch.
-# ~~~
-# add_library(custom_ops_aot_lib SHARED
-#   ${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp
-#   ${OUTPUT_DIRECTORY}/RegisterSchema.cpp
-#   ${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h)
-# ~~~
+# 1. C++ library to register custom ops into PyTorch.
+if(REGISTER_EXAMPLE_CUSTOM_OP_2)
+  add_library(
+    custom_ops_aot_lib SHARED
+    ${CMAKE_CURRENT_BINARY_DIR}/RegisterCPUCustomOps.cpp
+    ${CMAKE_CURRENT_BINARY_DIR}/RegisterSchema.cpp
+    ${CMAKE_CURRENT_BINARY_DIR}/CustomOpsNativeFunctions.h
+    ${CMAKE_CURRENT_LIST_DIR}/custom_ops_2.cpp # register my_ops::mul4
+    ${CMAKE_CURRENT_LIST_DIR}/custom_ops_2_out.cpp # register my_ops::mul4.out
+  )
+  # Find `Torch`.
+  find_package(Torch REQUIRED)
+  # ATen mode is on
+  target_compile_definitions(custom_ops_aot_lib PRIVATE USE_ATEN_LIB=1)
+  target_include_directories(custom_ops_aot_lib
+                             PUBLIC ${_common_include_directories})
+  include_directories(${TORCH_INCLUDE_DIRS})

-# Find `Torch`.
-# ~~~
-# find_package(Torch REQUIRED)
-# target_link_libraries(custom_ops_aot_lib PUBLIC Torch)
-# ~~~
+  target_link_libraries(custom_ops_aot_lib PRIVATE torch executorch)

-# 1. C++ library to register custom ops into Executorch runtime.
+  # Ensure that the load-time constructor functions run. By default, the linker
+  # would remove them since there are no other references to them.
+  if((CMAKE_CXX_COMPILER_ID MATCHES "AppleClang")
+     OR (APPLE AND CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
+    target_link_options(custom_ops_aot_lib INTERFACE
+                        "-Wl,-force_load,$<TARGET_FILE:custom_ops_aot_lib>")
+  elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU")
+    target_link_options(
+      custom_ops_aot_lib INTERFACE
+      "-Wl,--whole-archive,$<TARGET_FILE:custom_ops_aot_lib>,--no-whole-archive"
+    )
+  endif()
+endif()

+# 1. C++ library to register custom ops into Executorch runtime.
 add_library(custom_ops_lib)
 target_sources(
   custom_ops_lib
   PRIVATE
   ${CMAKE_CURRENT_BINARY_DIR}/RegisterCodegenUnboxedKernelsEverything.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/Functions.h
   ${CMAKE_CURRENT_BINARY_DIR}/NativeFunctions.h
-  ${CMAKE_CURRENT_BINARY_DIR}/CustomOpsNativeFunctions.h
-  ${CMAKE_CURRENT_LIST_DIR}/custom_ops_1.cpp)
+  ${CMAKE_CURRENT_BINARY_DIR}/CustomOpsNativeFunctions.h)
+if(REGISTER_EXAMPLE_CUSTOM_OP_1)
+  target_sources(custom_ops_lib
+                 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/custom_ops_1_out.cpp)
+elseif(REGISTER_EXAMPLE_CUSTOM_OP_2)
+  target_sources(custom_ops_lib
+                 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/custom_ops_2_out.cpp)
+endif()

 target_link_libraries(custom_ops_lib PRIVATE executorch)

examples/custom_ops/custom_ops_2.cpp

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <torch/library.h>
+
+namespace custom {
+namespace native {
+
+using at::Tensor;
+using c10::ScalarType;
+
+// mul4(Tensor input) -> Tensor
+Tensor mul4_impl(const Tensor& in) {
+  // naive approach
+  at::Tensor out = at::zeros_like(in);
+  out.copy_(in);
+  out.mul_(4);
+  return out;
+}
+
+TORCH_LIBRARY_FRAGMENT(my_ops, m) {
+  m.def(TORCH_SELECTIVE_SCHEMA("my_ops::mul4(Tensor input) -> Tensor"));
+}
+
+TORCH_LIBRARY_IMPL(my_ops, CompositeExplicitAutograd, m) {
+  m.impl(TORCH_SELECTIVE_NAME("my_ops::mul4"), TORCH_FN(mul4_impl));
+}
+} // namespace native
+} // namespace custom

examples/custom_ops/custom_ops_2.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Example of showcasing registering custom operator through torch library API."""
+import torch
+
+from examples.export.export_example import export_to_ff
+
+torch.ops.load_library("cmake-out/examples/custom_ops/libcustom_ops_aot_lib.so")
+
+# example model
+class Model(torch.nn.Module):
+    def forward(self, a):
+        return torch.ops.my_ops.mul4.default(a)
+
+
+def main():
+    m = Model()
+    input = torch.randn(2, 3)
+    # capture and lower
+    export_to_ff("custom_ops_2", m, (input,))
+
+
+if __name__ == "__main__":
+    main()
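Running this module saves custom_ops_2.pte; at runtime the program relies on the out-variant kernel my_ops::mul4.out registered by custom_ops_lib, which is why executor_runner must also be built with -DREGISTER_EXAMPLE_CUSTOM_OP_2=ON. A hypothetical end-to-end driver mirroring test_custom_ops.sh (paths assume the default cmake-out build directory):

import subprocess

# Export the model by running the example module above (writes ./custom_ops_2.pte).
subprocess.run(["python3", "-m", "examples.custom_ops.custom_ops_2"], check=True)

# Run the flatbuffer with an executor_runner built with the custom op registered.
subprocess.run(
    ["cmake-out/executor_runner", "--model_path=./custom_ops_2.pte"],
    check=True,
)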
examples/custom_ops/custom_ops_2_out.cpp

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace custom {
+namespace native {
+
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::RuntimeContext;
+
+namespace {
+void check_preconditions(const Tensor& in, Tensor& out) {
+  ET_CHECK_MSG(
+      out.scalar_type() == ScalarType::Float,
+      "Expected out tensor to have dtype Float, but got %hhd instead",
+      out.scalar_type());
+  ET_CHECK_MSG(
+      in.scalar_type() == ScalarType::Float,
+      "Expected in tensor to have dtype Float, but got %hhd instead",
+      in.scalar_type());
+  ET_CHECK_MSG(
+      out.dim() == in.dim(),
+      "Number of dims of out tensor is not compatible with inputs");
+  ET_CHECK_MSG(
+      out.numel() == in.numel(),
+      "Number of elements of out tensor %zd is not compatible with inputs %zd",
+      ssize_t(out.numel()),
+      ssize_t(in.numel()));
+}
+} // namespace
+// mul4.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
+Tensor& mul4_out_impl(const Tensor& in, Tensor& out) {
+  check_preconditions(in, out);
+  float* out_data = out.mutable_data_ptr<float>();
+  const float* in_data = in.const_data_ptr<float>();
+  for (size_t out_idx = 0; out_idx < out.numel(); ++out_idx) {
+    out_data[out_idx] = in_data[out_idx] * 4;
+  }
+  return out;
+}
+
+Tensor& mul4_out_impl(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
+  (void)ctx;
+  mul4_out_impl(in, out);
+  return out;
+}
+
+} // namespace native
+} // namespace custom

examples/custom_ops/targets.bzl

Lines changed: 12 additions & 7 deletions

@@ -15,9 +15,12 @@ def define_common_targets():
         ],
     )

+    # ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
     et_operator_library(
-        name = "executorch_all_ops",
-        include_all_operators = True,
+        name = "select_custom_ops_1",
+        ops = [
+            "my_ops::mul3.out",
+        ],
         define_static_targets = True,
         visibility = [
             "//executorch/codegen/...",
@@ -26,8 +29,8 @@ def define_common_targets():
     )

     runtime.cxx_library(
-        name = "custom_kernel_lib",
-        srcs = ["custom_ops_1.cpp"],
+        name = "custom_ops_1",
+        srcs = ["custom_ops_1_out.cpp"],
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
         ],
@@ -38,14 +41,16 @@ def define_common_targets():
     )

     executorch_generated_lib(
-        name = "generated_lib",
+        name = "lib_1",
         deps = [
-            ":executorch_all_ops",
-            ":custom_kernel_lib",
+            ":select_custom_ops_1",
+            ":custom_ops_1",
         ],
         custom_ops_yaml_target = ":custom_ops.yaml",
         visibility = [
             "//executorch/...",
             "@EXECUTORCH_CLIENTS",
         ],
     )
+
+    # ~~~ END of custom ops 1 `my_ops::mul3` library definitions ~~~

examples/custom_ops/test_custom_ops.sh

Lines changed: 26 additions & 2 deletions

@@ -19,7 +19,7 @@ test_buck2_custom_op_1() {

   echo 'Running executor_runner'
   buck2 run //fbcode/executorch/examples/executor_runner:executor_runner \
-      --config=executorch.include_custom_ops=1 -- --model_path="./${model_name}.pte"
+      --config=executorch.register_custom_op_1=1 -- --model_path="./${model_name}.pte"
   # should give correct result

   echo "Removing ${model_name}.pte"
@@ -34,7 +34,7 @@ test_cmake_custom_op_1() {
   (rm -rf cmake-out \
     && mkdir cmake-out \
     && cd cmake-out \
-    && cmake -DBUCK2=buck2 -DBUILD_EXAMPLE_CUSTOM_OPS=ON ..)
+    && cmake -DBUCK2=buck2 -DREGISTER_EXAMPLE_CUSTOM_OP_1=ON ..)

   echo 'Building executor_runner'
   cmake --build cmake-out -j9
@@ -43,5 +43,29 @@
   cmake-out/executor_runner --model_path="./${model_name}.pte"
 }

+test_cmake_custom_op_2() {
+  local model_name='custom_ops_2'
+  SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
+  CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
+
+  (rm -rf cmake-out \
+    && mkdir cmake-out \
+    && cd cmake-out \
+    && cmake -DBUCK2=buck2 \
+      -DREGISTER_EXAMPLE_CUSTOM_OP_2=ON \
+      -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" ..)
+
+  echo 'Building executor_runner'
+  cmake --build cmake-out -j9
+
+  echo "Exporting ${model_name}.pte"
+  python3 -m "examples.custom_ops.${model_name}"
+  # should save file custom_ops_2.pte
+
+  echo 'Running executor_runner'
+  cmake-out/executor_runner "--model_path=./${model_name}.pte"
+}
+
 test_buck2_custom_op_1
 test_cmake_custom_op_1
+test_cmake_custom_op_2

examples/executor_runner/targets.bzl

Lines changed: 4 additions & 2 deletions

@@ -7,7 +7,9 @@ def define_common_targets():
     TARGETS and BUCK files that call this function.
     """

-    include_custom_ops = native.read_config("executorch", "include_custom_ops", "0") == "1"
+    register_custom_op_1 = native.read_config("executorch", "register_custom_op_1", "0") == "1"
+
+    custom_ops_lib = ["//executorch/examples/custom_ops:lib_1"] if register_custom_op_1 else []

     # Test driver for models, uses all portable kernels.
     runtime.cxx_binary(
@@ -19,7 +21,7 @@ def define_common_targets():
             "//executorch/extension/data_loader:file_data_loader",
             "//executorch/util:util",
             "//executorch/kernels/portable:generated_lib_all_ops",
-        ] + (["//executorch/examples/custom_ops:generated_lib"] if include_custom_ops else []),
+        ] + custom_ops_lib,
         external_deps = [
             "gflags",
         ],
