docs/source/kernel-library-custom-aten-kernel.md

ATen operator with a dtype/dim order specialized kernel (works for `Double` dtype):

```yaml
      kernel_name: torch::executor::add_out
```

### Custom Ops C++ API

For a custom kernel that implements a custom operator, we provide two ways to register it into the ExecuTorch runtime:

1. Using `EXECUTORCH_LIBRARY` and `WRAP_TO_ATEN` C++ macros.
2. Using `functions.yaml` and codegen'd C++ libraries.

The first option requires C++17 and doesn't have selective build support yet, but it's faster to set up than the second option, which requires yaml authoring and build-system tweaks.

The first option is particularly suitable for fast prototyping but can also be used in production.

Similar to `TORCH_LIBRARY`, `EXECUTORCH_LIBRARY` takes the operator name and the C++ function name, and registers them into the ExecuTorch runtime.

#### Prepare custom kernel implementation

Define your custom operator schema for both the functional variant (used in AOT compilation) and the out variant (used in the ExecuTorch runtime). The schema needs to follow the PyTorch ATen convention (see `native_functions.yaml`). For example:
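
The following is a minimal sketch for a `custom_linear` operator; the schema strings, header layout, and kernel body are illustrative, and exact includes and type aliases can vary across ExecuTorch versions:

```cpp
// custom_linear.h / custom_linear.cpp
//
// Schema, following the ATen convention (functional + out variant):
//   custom_linear(Tensor weight, Tensor input, Tensor? bias) -> Tensor
//   custom_linear.out(Tensor weight, Tensor input, Tensor? bias, *, Tensor(a!) out) -> Tensor(a!)
#include <executorch/runtime/kernel/kernel_includes.h>

using exec_aten::optional;
using exec_aten::Tensor;

Tensor& custom_linear_out(
    const Tensor& weight,
    const Tensor& input,
    optional<Tensor> bias,
    Tensor& out) {
  // ... do the actual computation here, writing the result into `out` ...
  return out;
}

// Register the out variant into the ExecuTorch runtime, under the
// opset namespace `myop`.
EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out);
```
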
Now we need to write a wrapper for this op so it shows up in PyTorch; there is no need to rewrite the kernel. Create a separate `.cpp` file for this purpose:
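
A sketch of such a wrapper, assuming the `custom_linear_out` kernel above (the output-shape computation is illustrative, and the spelling of the optional type can vary with the PyTorch version):

```cpp
// custom_linear_pytorch.cpp
#include <torch/library.h>
#include <torch/torch.h>

// Declares custom_linear_out and pulls in WRAP_TO_ATEN via the
// ExecuTorch ATen-util extension headers.
#include "custom_linear.h"

// Functional variant: allocate the output, then forward to the
// ExecuTorch out-variant kernel.
at::Tensor custom_linear(
    const at::Tensor& weight,
    const at::Tensor& input,
    std::optional<at::Tensor> bias) {
  // Illustrative output shape; compute the real shape for your op.
  at::Tensor out = at::empty({weight.size(1), input.size(1)});
  // WRAP_TO_ATEN adapts the ExecuTorch kernel to ATen tensors; the second
  // argument is the position of the out tensor in the argument list.
  WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out);
  return out;
}

// Standard PyTorch registration, so the op is visible as torch.ops.myop.*.
TORCH_LIBRARY(myop, m) {
  m.def("custom_linear(Tensor weight, Tensor input, Tensor? bias) -> Tensor",
        custom_linear);
  m.def(
      "custom_linear.out(Tensor weight, Tensor input, Tensor? bias, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(custom_linear_out, 3));
}
```
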
Link it into the ExecuTorch runtime: in the `CMakeLists.txt` that builds the binary/application, add `custom_linear.h/cpp` into the binary target. We can also build it as a dynamically loaded library (`.so` or `.dylib`) and link that instead.
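
A sketch of the CMake side, with a hypothetical `my_executorch_app` target:

```cmake
# Compile the kernel directly into the application binary...
target_sources(my_executorch_app PRIVATE custom_linear.cpp)

# ...or build it as a shared library and link it instead.
add_library(custom_linear SHARED custom_linear.cpp)
target_link_libraries(my_executorch_app PRIVATE custom_linear)
```
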
Link it into the PyTorch runtime: we need to package `custom_linear.h`, `custom_linear.cpp` and `custom_linear_pytorch.cpp` into a dynamically loaded library (`.so` or `.dylib`) and load it into our Python environment. One way of doing this is:

```python
import torch

# Load the shared library that registers the op (use the .dylib on macOS).
torch.ops.load_library("libcustom_linear.so")

# Now we have access to the custom op, backed by kernel implemented in custom_linear.cpp.
op = torch.ops.myop.custom_linear.default
```

### Custom Ops Yaml Entry

As mentioned above, this option supports selective build and features such as merging operator libraries.

First we need to specify the operator schema as well as a `kernel` section. So instead of `op`, we use `func` with the full operator schema. As an example, here's a yaml entry for a custom op:
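
A sketch of such an entry, reusing the `custom_linear` schema from the C++ section above (the kernel name points at whichever C++ function implements the out variant):

```yaml
- func: custom_linear.out(Tensor weight, Tensor input, Tensor? bias, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::custom_linear_out
```
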
We also provide the ability to merge two yaml files, given a precedence. `merge_yaml(FUNCTIONS_YAML functions_yaml FALLBACK_YAML fallback_yaml OUTPUT_DIR out_dir)` merges `functions_yaml` and `fallback_yaml` into a single yaml. If there are duplicate entries in `functions_yaml` and `fallback_yaml`, this macro always takes the entry in `functions_yaml`.

Example:

```yaml
# functions.yaml
- op: add.out
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::opt_add_out
```

And our fallback:

```yaml
# fallback.yaml
- op: add.out
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::add_out
```

The merged yaml will contain the entry from `functions.yaml`.

#### Buck2

`executorch_generated_lib` is the macro that takes the yaml files and depends on the selective build macro `et_operator_library`. For example:
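
A sketch of what this can look like in a Buck2 `TARGETS` file; the target names and yaml paths are hypothetical, and the `load()` of the codegen macros is elided:

```python
# Selective build: keep only the listed ops.
et_operator_library(
    name = "select_add_out",
    ops = ["aten::add.out"],
)

# Generate and build the kernel registration library from the yaml files.
executorch_generated_lib(
    name = "my_kernel_lib",
    functions_yaml_target = ":functions.yaml",
    deps = [":select_add_out"],
)
```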