Skip to content

Commit 9581086

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add buck2 example for custom ops
Summary: Similar to D48103652 and D48184410 but focusing on buck2 support. This example demonstrates the ability to build a `exir_custom_ops_aot_lib` by using only C++ implementations of both functional and out variants of the custom op. Specifically, let's assume a user has an op `my_ops::mul4` and already has a C++ implementation of it. The user registers this op through PyTorch C++ op registration API. Now they want to run the same op on Executorch, naturally they need `my_ops::mul4.out` and a C++ implementation of it so that Executorch runtime can consume. `exir_custom_ops_aot_lib` is able to reuse this C++ impl for Executorch runtime, link it with `libtorch` and generate code to register it into PyTorch JIT runtime. This way the user doesn't have to write an extra Python impl. This is the exact behavior demonstrated in D48184410, but instead of relying on CMake this diff enables BUCK2 targets. Reviewed By: kimishpatel Differential Revision: D48296936 fbshipit-source-id: 58ab7127f853549d29e7b25648dd7098219fe5c3
1 parent d1563b6 commit 9581086

File tree

9 files changed

+191
-38
lines changed

9 files changed

+191
-38
lines changed

examples/custom_ops/custom_ops_2.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
library that calls PyTorch C++ op registration API.
99
"""
1010

11+
import argparse
12+
1113
import torch
1214

1315
from examples.export.export_example import export_to_pte
1416

15-
1617
# example model
1718
class Model(torch.nn.Module):
1819
def forward(self, a):
@@ -22,21 +23,21 @@ def forward(self, a):
2223
def main():
2324
m = Model()
2425
input = torch.randn(2, 3)
25-
# load shared library
26-
from sys import platform
27-
28-
if platform == "linux" or platform == "linux2":
29-
extension = ".so"
30-
elif platform == "darwin":
31-
extension = ".dylib"
32-
else:
33-
raise RuntimeError(f"Unsupported platform {platform}")
34-
torch.ops.load_library(
35-
f"cmake-out/examples/custom_ops/libcustom_ops_aot_lib{extension}"
36-
)
26+
3727
# capture and lower
3828
export_to_pte("custom_ops_2", m, (input,))
3929

4030

4131
if __name__ == "__main__":
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument(
34+
"-s",
35+
"--so_library",
36+
required=True,
37+
help="Provide path to so library. E.g., cmake-out/examples/custom_ops/libcustom_ops_aot_lib.so",
38+
)
39+
args = parser.parse_args()
40+
torch.ops.load_library(args.so_library)
41+
print(args.so_library)
42+
4243
main()

examples/custom_ops/targets.bzl

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2-
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
2+
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib", "exir_custom_ops_aot_lib")
33

44
def define_common_targets():
55
"""Defines targets that should be shared between fbcode and xplat.
@@ -54,3 +54,75 @@ def define_common_targets():
5454
)
5555

5656
# ~~~ END of custom ops 1 `my_ops::mul3` library definitions ~~~
57+
# ~~~ START of custom ops 2 `my_ops::mul4` library definitions ~~~
58+
59+
et_operator_library(
60+
name = "select_custom_ops_2",
61+
ops = [
62+
"my_ops::mul4.out",
63+
],
64+
define_static_targets = True,
65+
visibility = [
66+
"//executorch/codegen/...",
67+
"@EXECUTORCH_CLIENTS",
68+
],
69+
)
70+
71+
runtime.cxx_library(
72+
name = "custom_ops_2",
73+
srcs = ["custom_ops_2_out.cpp"],
74+
deps = [
75+
"//executorch/runtime/kernel:kernel_includes",
76+
],
77+
visibility = [
78+
"//executorch/...",
79+
"@EXECUTORCH_CLIENTS",
80+
],
81+
)
82+
83+
runtime.cxx_library(
84+
name = "custom_ops_2_aten",
85+
srcs = [
86+
"custom_ops_2.cpp",
87+
"custom_ops_2_out.cpp",
88+
],
89+
deps = [
90+
"//executorch/runtime/kernel:kernel_includes_aten",
91+
],
92+
visibility = [
93+
"//executorch/...",
94+
"@EXECUTORCH_CLIENTS",
95+
],
96+
external_deps = ["libtorch"],
97+
# @lint-ignore BUCKLINT link_whole
98+
link_whole = True,
99+
# WARNING: using a deprecated API to avoid being built into a shared
100+
# library. In the case of dynamically loading so library we don't want
101+
# it to depend on other so libraries because that way we have to
102+
# specify library directory path.
103+
force_static = True,
104+
)
105+
106+
exir_custom_ops_aot_lib(
107+
name = "custom_ops_aot_lib_2",
108+
yaml_target = ":custom_ops.yaml",
109+
visibility = ["//executorch/..."],
110+
kernels = [":custom_ops_2_aten"],
111+
deps = [
112+
":select_custom_ops_2",
113+
],
114+
)
115+
116+
executorch_generated_lib(
117+
name = "lib_2",
118+
deps = [
119+
":select_custom_ops_2",
120+
":custom_ops_2",
121+
],
122+
custom_ops_yaml_target = ":custom_ops.yaml",
123+
visibility = [
124+
"//executorch/...",
125+
"@EXECUTORCH_CLIENTS",
126+
],
127+
)
128+
# ~~~ END of custom ops 2 `my_ops::mul4` library definitions ~~~

examples/custom_ops/test_custom_ops.sh

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ test_buck2_custom_op_1() {
1919

2020
echo 'Running executor_runner'
2121
buck2 run //examples/executor_runner:executor_runner \
22-
--config=executorch.register_custom_op_1=1 -- --model_path="./${model_name}.pte"
22+
--config=executorch.register_custom_op=1 -- --model_path="./${model_name}.pte"
2323
# should give correct result
2424

2525
echo "Removing ${model_name}.pte"
@@ -43,6 +43,38 @@ test_cmake_custom_op_1() {
4343
cmake-out/executor_runner --model_path="./${model_name}.pte"
4444
}
4545

46+
test_buck2_custom_op_2() {
47+
local model_name='custom_ops_2'
48+
49+
echo 'Building custom ops shared library'
50+
SO_LIB=$(buck2 build //examples/custom_ops:custom_ops_aot_lib_2 --show-output | grep "buck-out" | cut -d" " -f2)
51+
52+
echo "Exporting ${model_name}.pte"
53+
python3 -m "examples.custom_ops.${model_name}" --so_library="$SO_LIB"
54+
# should save file custom_ops_2.pte
55+
56+
buck2 run //examples/executor_runner:executor_runner \
57+
--config=executorch.register_custom_op=2 -- --model_path="./${model_name}.pte"
58+
# should give correct result
59+
echo "Removing ${model_name}.pte"
60+
rm "./${model_name}.pte"
61+
}
62+
63+
get_shared_lib_ext() {
64+
UNAME=$(uname)
65+
if [[ $UNAME == "Darwin" ]];
66+
then
67+
EXT=".dylib"
68+
elif [[ $UNAME == "Linux" ]];
69+
then
70+
EXT=".so"
71+
else
72+
echo "Unsupported platform $UNAME"
73+
exit 1
74+
fi
75+
echo $EXT
76+
}
77+
4678
test_cmake_custom_op_2() {
4779
local model_name='custom_ops_2'
4880
SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
@@ -56,10 +88,11 @@ test_cmake_custom_op_2() {
5688
-DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" ..)
5789

5890
echo 'Building executor_runner'
59-
cmake --build cmake-out -j9
91+
cmake --build cmake-out -j4
6092

93+
EXT=$(get_shared_lib_ext)
6194
echo "Exporting ${model_name}.pte"
62-
python3 -m "examples.custom_ops.${model_name}"
95+
python3 -m "examples.custom_ops.${model_name}" --so_library="cmake-out/examples/custom_ops/libcustom_ops_aot_lib$EXT"
6396
# should save file custom_ops_2.pte
6497

6598
echo 'Running executor_runner'
@@ -68,4 +101,5 @@ test_cmake_custom_op_2() {
68101

69102
test_buck2_custom_op_1
70103
test_cmake_custom_op_1
104+
test_buck2_custom_op_2
71105
test_cmake_custom_op_2

examples/executor_runner/targets.bzl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ def define_common_targets():
77
TARGETS and BUCK files that call this function.
88
"""
99

10-
register_custom_op_1 = native.read_config("executorch", "register_custom_op_1", "0") == "1"
10+
register_custom_op = native.read_config("executorch", "register_custom_op", "0")
1111

12-
custom_ops_lib = ["//executorch/examples/custom_ops:lib_1"] if register_custom_op_1 else []
12+
if register_custom_op == "1":
13+
custom_ops_lib = ["//executorch/examples/custom_ops:lib_1"]
14+
elif register_custom_op == "2":
15+
custom_ops_lib = ["//executorch/examples/custom_ops:lib_2"]
16+
else:
17+
custom_ops_lib = []
1318

1419
# Test driver for models, uses all portable kernels.
1520
runtime.cxx_binary(

runtime/core/exec_aten/util/targets.bzl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,9 @@ def define_common_targets():
6161
":scalar_type_util" + aten_suffix,
6262
":dim_order_util" + aten_suffix,
6363
],
64+
# WARNING: using a deprecated API to avoid being built into a shared
65+
# library. In the case of dynamically loading so library we don't want
66+
# it to depend on other so libraries because that way we have to
67+
# specify library directory path.
68+
force_static = True,
6469
)

runtime/platform/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def define_common_targets():
5050
visibility = [
5151
"//executorch/core/...",
5252
],
53+
# WARNING: using a deprecated API to avoid being built into a shared
54+
# library. In the case of dynamically loading so library we don't want
55+
# it to depend on other so libraries because that way we have to
56+
# specify library directory path.
57+
force_static = True,
5358
)
5459

5560
# Interfaces for executorch users
@@ -78,6 +83,11 @@ def define_common_targets():
7883
"//executorch/...",
7984
"@EXECUTORCH_CLIENTS",
8085
],
86+
# WARNING: using a deprecated API to avoid being built into a shared
87+
# library. In the case of dynamically loading so library we don't want
88+
# it to depend on other so libraries because that way we have to
89+
# specify library directory path.
90+
force_static = True,
8191
)
8292

8393
# Library for backend implementers to define implementations against.

shim/xplat/executorch/build/env_interface.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def _remove_platform_specific_args(kwargs):
107107
kwargs.pop(key)
108108
return kwargs
109109

110+
def _remove_unsupported_kwargs(kwargs):
111+
"""Removes environment unsupported kwargs
112+
"""
113+
return kwargs
114+
110115
def _patch_headers(kwargs):
111116
"""Patch (add or modify or remove) headers related attributes for this build environment.
112117
"""
@@ -202,6 +207,7 @@ env = struct(
202207
# @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
203208
python_test = native.python_test,
204209
remove_platform_specific_args = _remove_platform_specific_args,
210+
remove_unsupported_kwargs = _remove_unsupported_kwargs,
205211
resolve_external_dep = _resolve_external_dep,
206212
target_needs_patch = _target_needs_patch,
207213
EXTERNAL_DEP_FALLTHROUGH = _EXTERNAL_DEP_FALLTHROUGH,

shim/xplat/executorch/build/runtime_wrapper.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def _patch_kwargs_common(kwargs):
208208
return kwargs
209209

210210
def _patch_kwargs_cxx(kwargs):
211+
env.remove_unsupported_kwargs(kwargs)
211212
env.patch_platforms(kwargs)
212213
env.remove_platform_specific_args(kwargs)
213214
return _patch_kwargs_common(kwargs)

third-party/TARGETS

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -218,35 +218,54 @@ runtime.cxx_binary(
218218

219219
runtime.genrule(
220220
name = "libtorch_gen",
221-
out = "libtorch.so",
221+
outs = select({
222+
"ovr_config//os:macos": {
223+
"libtorch": ["libtorch.dylib"],
224+
"libc10": ["libc10.dylib"],
225+
"libtorch_cpu": ["libtorch_cpu.dylib"],
226+
"include": ["include"],
227+
},
228+
"DEFAULT": {
229+
"libtorch": ["libtorch.so"],
230+
"libc10": ["libc10.so"],
231+
"libtorch_cpu": ["libtorch_cpu.so"],
232+
"include": ["include"],
233+
},
234+
}),
235+
default_outs = ["."],
222236
srcs = ["link_torch.sh"],
223-
bash = 'bash $SRCS torch/lib/libtorch.so ${OUT}',
237+
bash = select({
238+
"ovr_config//os:macos": "bash $SRCS -f torch/lib/libtorch.dylib,torch/lib/libtorch_cpu.dylib,torch/lib/libc10.dylib,torch/include -o ${OUT}",
239+
"DEFAULT": "bash $SRCS -f torch/lib/libtorch.so,torch/lib/libtorch_cpu.so,torch/lib/libc10.so,torch/include -o ${OUT}",
240+
}),
224241
)
225242

226-
runtime.genrule(
227-
name = "torch_headers_gen",
228-
out = "include",
229-
srcs = ["link_torch.sh"],
230-
bash = 'bash $SRCS torch/include ${OUT}',
243+
prebuilt_cxx_library(
244+
name = "libc10",
245+
shared_lib = ":libtorch_gen[libc10]",
231246
)
232247

233-
runtime.genrule(
234-
name = "libtorch_dir",
235-
out = "lib",
236-
srcs = ["link_torch.sh"],
237-
bash = 'bash $SRCS torch/lib ${OUT}',
248+
prebuilt_cxx_library(
249+
name = "libtorch_cpu",
250+
shared_lib = ":libtorch_gen[libtorch_cpu]",
238251
)
239252

240253
prebuilt_cxx_library(
241254
name = "libtorch",
242-
shared_lib = ":libtorch_gen",
243-
soname = "libtorch.so",
255+
shared_lib = ":libtorch_gen[libtorch]",
244256
exported_preprocessor_flags = [
245-
"-D_GLIBCXX_USE_CXX11_ABI=0", # `libtorch` is built without CXX11_ABI so any target depends on it need to use the same build config.
246-
"-I$(location :torch_headers_gen)", # include header directories
247-
"-I$(location :torch_headers_gen)/torch/csrc/api/include", # include header directories
257+
"-D_GLIBCXX_USE_CXX11_ABI=0", # `libtorch` is built without CXX11_ABI so any target depends on it need to use the same build config.
258+
"-I$(location :libtorch_gen[include])", # include header directories
259+
"-I$(location :libtorch_gen[include])/torch/csrc/api/include", # include header directories
260+
],
261+
exported_linker_flags = select({
262+
"ovr_config//os:macos": ["-Xlinker", "-rpath", "$(location :libtorch_gen)", "-Xlinker"],
263+
"DEFAULT": ["-Wl,-rpath,$(location :libtorch_gen)"], # define rpath to locate shared library
264+
}),
265+
exported_headers = [":libtorch_gen[include]"],
266+
exported_deps = [
267+
":libc10",
268+
":libtorch_cpu",
248269
],
249-
exported_linker_flags = ["-Wl,-rpath,$(location :libtorch_dir)"], # define rpath to locate shared library
250-
exported_headers = [":torch_headers_gen"],
251270
visibility = ["PUBLIC"],
252271
)

0 commit comments

Comments
 (0)