
Commit 9870e11

larryliu0820 authored and facebook-github-bot committed
Add custom ops registration to examples
Summary:

## Context

I plan to add 2 (or 3) examples for different custom ops registration mechanisms. Users should be able to use any of these options to use their custom ops. A proper README.md will be added.

## Solution

For the first option, we support the traditional PyTorch op registration Python API. This requires users to write Python implementations of both the functional op and the out variant op, as demonstrated in this diff. Note that these ops are only registered into the PyTorch JIT runtime for EXIR to consume. We also use the buck2 target macro `executorch_generated_lib` to register custom ops into the Executorch runtime.

For the second option, we want to leverage the C++ kernel the user wrote for the Executorch runtime, treat it as a valid PyTorch op kernel, and register it into the PyTorch JIT runtime. This way we don't have to write any Python kernel. This can be done through the CMake build, by pulling in the PyTorch C++ dependency and then enabling ATen mode. This will be done once CMake diff D47927863 lands.

The third option will be the same as the first, but on the CMake build system. Note that CMake and Buck2 will then have different capabilities, because pulling the PyTorch C++ lib into Buck2 can't reuse the existing BUCK files.

Reviewed By: cccclai

Differential Revision: D48054313

fbshipit-source-id: 15fe77a4a69f3260b8fe09d7ce51b2f1e92cce68
1 parent 789f4ce commit 9870e11

File tree

8 files changed

+226
-0
lines changed


examples/custom_ops/README.md

Lines changed: 31 additions & 0 deletions
# Custom Operator Registration Examples (WIP)

This folder contains examples that register custom operators into PyTorch and register their kernels into the Executorch runtime.

## How to run

Prerequisite: finish the [setting up wiki](https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md).

Run:

```bash
bash test_custom_ops.sh
```

## AOT registration

To use custom ops in the Executorch AOT flow (EXIR), the first option is to register the custom ops into the PyTorch JIT runtime using the `torch.library` APIs.

See the example in `custom_ops_1.py`, where we register `my_ops::mul3` and `my_ops::mul3_out`. `my_ops` is the namespace, and it shows up when we use the operator, e.g. `torch.ops.my_ops.mul3.default`. For more information about PyTorch operators, check out [`pytorch/torch/_ops.py`](https://github.com/pytorch/pytorch/blob/main/torch/_ops.py).

Notice that we need both the functional variant and the out variant of a custom op, because EXIR performs memory planning on the out variant `my_ops::mul3_out`.

## C++ kernel registration

After the model is exported by EXIR, we need C++ implementations of these custom ops in order to run it. `custom_ops_1.cpp` is an example C++ kernel. Besides the kernel itself, we also need a way to bind the PyTorch op to it. This binding is specified in `custom_ops.yaml`:

```yaml
- func: my_ops::mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: custom::mul3_out_impl # sub-namespace native:: is auto-added
```

For how to write these YAML entries, please refer to [`kernels/portable/README.md`](https://github.com/pytorch/executorch/blob/main/kernels/portable/README.md).
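The `torch.library` registration flow the README describes can be sketched standalone. This is a minimal sketch mirroring the names in `custom_ops_1.py`; the eager call at the end is only for illustration and is not part of the diff:

```python
import torch
from torch.library import Library, impl

# Define a new op schema under the "my_ops" namespace, as in custom_ops_1.py.
my_op_lib = Library("my_ops", "DEF")
my_op_lib.define("mul3(Tensor input) -> Tensor")


@impl(my_op_lib, "mul3", dispatch_key="CompositeExplicitAutograd")
def mul3_impl(a: torch.Tensor) -> torch.Tensor:
    # Kernel body: multiply the input by 3.
    return a * 3


# The registered op is reachable as torch.ops.<namespace>.<name>.
x = torch.tensor([1.0, 2.0])
y = torch.ops.my_ops.mul3(x)
print(y)  # tensor([3., 6.])
```

Once registered this way, the op behaves like any built-in PyTorch operator, which is what lets EXIR trace and consume it during export.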

examples/custom_ops/TARGETS

Lines changed: 8 additions & 0 deletions
```bzl
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
```

examples/custom_ops/custom_ops.yaml

Lines changed: 10 additions & 0 deletions
```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# See the kernels/portable/README.md for a description of the syntax used
# by this file.

# important to keep the namespace
- func: my_ops::mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: custom::mul3_out_impl # sub-namespace native:: is auto-added
```
examples/custom_ops/custom_ops_1.cpp

Lines changed: 51 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>

namespace custom {
namespace native {

using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::RuntimeContext;

namespace {
void check_preconditions(const Tensor& in, Tensor& out) {
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "Expected out tensor to have dtype Float, but got %hhd instead",
      out.scalar_type());
  ET_CHECK_MSG(
      in.scalar_type() == ScalarType::Float,
      "Expected in tensor to have dtype Float, but got %hhd instead",
      in.scalar_type());
  ET_CHECK_MSG(
      out.dim() == in.dim(),
      "Number of dims of out tensor is not compatible with inputs");
  ET_CHECK_MSG(
      out.numel() == in.numel(),
      "Number of elements of out tensor %zd is not compatible with inputs %zd",
      ssize_t(out.numel()),
      ssize_t(in.numel()));
}
} // namespace

// mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
Tensor& mul3_out_impl(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
  (void)ctx;

  check_preconditions(in, out);
  float* out_data = out.mutable_data_ptr<float>();
  const float* in_data = in.const_data_ptr<float>();
  for (size_t out_idx = 0; out_idx < out.numel(); ++out_idx) {
    out_data[out_idx] = in_data[out_idx] * 3;
  }
  return out;
}
} // namespace native
} // namespace custom
```

examples/custom_ops/custom_ops_1.py

Lines changed: 51 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Example showcasing how to register a custom operator through the torch.library API."""
import torch

from examples.export.export_example import export_to_ff
from torch.library import impl, Library

my_op_lib = Library("my_ops", "DEF")

# Register an operator that multiplies the input tensor by 3 and returns it.
my_op_lib.define("mul3(Tensor input) -> Tensor")  # should print 'mul3'


@impl(my_op_lib, "mul3", dispatch_key="CompositeExplicitAutograd")
def mul3_impl(a: torch.Tensor) -> torch.Tensor:
    return a * 3


# Register the out variant.
my_op_lib.define(
    "mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)"
)  # should print 'mul3.out'


@impl(my_op_lib, "mul3.out", dispatch_key="CompositeExplicitAutograd")
def mul3_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    out.copy_(a * 3)  # write into out without mutating the input tensor
    return out


# Example model.
class Model(torch.nn.Module):
    def forward(self, a):
        return torch.ops.my_ops.mul3.default(a)


def main():
    m = Model()
    input = torch.randn(2, 3)
    # Capture and lower.
    export_to_ff("custom_ops_1", m, (input,))


if __name__ == "__main__":
    main()
```
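In eager mode, the out variant registered above can be invoked through its overload name. This is a self-contained sketch (it re-registers the op so it runs on its own; note the keyword argument name `output` comes from the schema):

```python
import torch
from torch.library import Library, impl

lib = Library("my_ops", "DEF")
lib.define("mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")


@impl(lib, "mul3.out", dispatch_key="CompositeExplicitAutograd")
def mul3_out_impl(a: torch.Tensor, *, output: torch.Tensor) -> torch.Tensor:
    output.copy_(a * 3)  # fill the caller-provided output buffer
    return output


a = torch.tensor([1.0, 2.0, 3.0])
out = torch.empty_like(a)
# "out" here is the overload name from the schema "mul3.out(...)".
torch.ops.my_ops.mul3.out(a, output=out)
print(out)  # tensor([3., 6., 9.])
```

EXIR rewrites calls to the functional `mul3` into this out variant during memory planning, which is why both registrations are required.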

examples/custom_ops/targets.bzl

Lines changed: 51 additions & 0 deletions
```bzl
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    runtime.export_file(
        name = "custom_ops.yaml",
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )

    et_operator_library(
        name = "executorch_all_ops",
        include_all_operators = True,
        define_static_targets = True,
        visibility = [
            "//executorch/codegen/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )

    runtime.cxx_library(
        name = "custom_kernel_lib",
        srcs = ["custom_ops_1.cpp"],
        deps = [
            "//executorch/runtime/kernel:kernel_includes",
        ],
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )

    executorch_generated_lib(
        name = "generated_lib",
        deps = [
            ":executorch_all_ops",
            ":custom_kernel_lib",
        ],
        custom_ops_yaml_target = ":custom_ops.yaml",
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )
```
examples/custom_ops/test_custom_ops.sh

Lines changed: 23 additions & 0 deletions

```bash
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Test the end-to-end flow: use a custom operator in a PyTorch model, capture
# and export a model file with EXIR, then run the model with the
# `executor_runner` demo C++ binary.

test_custom_op_1() {
  echo 'Exporting custom_ops_1.pte'
  python3 -m examples.custom_ops.custom_ops_1
  # should save file custom_ops_1.pte

  echo 'Running executor_runner'
  buck2 run //fbcode/executorch/examples/executor_runner:executor_runner -- --model_path=./custom_ops_1.pte
  # should give correct result

  echo 'Removing custom_ops_1.pte'
  rm ./custom_ops_1.pte
}

test_custom_op_1
```
examples/executor_runner/targets.bzl

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@ def define_common_targets():
         "//executorch/extension/data_loader:file_data_loader",
         "//executorch/util:util",
         "//executorch/kernels/portable:generated_lib_all_ops",
+        "//executorch/examples/custom_ops:generated_lib",
     ],
     external_deps = [
         "gflags",
```
