Skip to content

Commit fd88b27

Browse files
mcr229facebook-github-bot
authored andcommitted
XNNPACK Lowering Tutorial (#634)
Summary: Pull Request resolved: #634 [preview](https://deploy-preview-634--resplendent-gnome-14e531.netlify.app) Reviewed By: digantdesai Differential Revision: D49945479 fbshipit-source-id: 4350c752f081b01bf0dc18733ff9068dda3f95e5
1 parent 2e33d62 commit fd88b27

File tree

3 files changed

+134
-1
lines changed

3 files changed

+134
-1
lines changed

docs/source/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ Topics in this section will help you get started with ExecuTorch.
198198
demo-apps-android
199199
build-run-xtensa
200200
tutorials/sdk-integration-tutorial
201+
tutorial-xnnpack-delegate-lowering
201202

202203
Tutorials and Examples
203204
~~~~~~~~~~~~~~~~~~~~~~
@@ -242,4 +243,11 @@ ExecuTorch tutorials.
242243
:link: build-run-xtensa.html
243244
:tags: DSP
244245

246+
.. customcarditem::
247+
:header: XNNPACK Backend Delegate Lowering Tutorial
248+
:card_description: A demo tutorial for lowering and export models with the XNNPACK Backend
249+
:image: _static/img/generic-pytorch-logo.png
250+
:link: tutorial-xnnpack-delegate-lowering.html
251+
:tags: Export,Delegation,Quantization,XNNPACK
252+
245253
.. customcardend::

docs/source/native-delegates-executorch-xnnpack-delegate.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,5 +154,5 @@ def _qdq_quantized_linear(
154154
You can read more indepth explanations on PyTorch 2 quantization [here](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html).
155155

156156
## See Also
157-
- Lowering to XNNPACK Tutorial (TBD)
158157
- [Integrating XNNPACK Delegate Android App](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/ExecuTorchDemo/README.md)
158+
- [Complete the Lowering to XNNPACK Tutorial](tutorial-xnnpack-delegate-lowering.md)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# XNNPACK Backend Delegate Lowering Tutorial
2+
3+
The following tutorial will familiarize you with leveraging the ExecuTorch XNNPACK Delegate for accelerating your ML Models using CPU hardware. It will go over exporting and serializing a model to a binary file, targeting the XNNPACK Delegate Backend and running the model on a supported target platform. To get started quickly, use the script in the ExecuTorch repository with instructions on exporting and generating a binary file for a few sample models demonstrating the flow.
4+
5+
<!----This will show a grid card on the page----->
6+
::::{grid} 2
7+
:::{grid-item-card} What you will learn in this tutorial:
8+
:class-card: card-learn
9+
In this tutorial, you will learn how to export an XNNPACK lowered Model and run it on a target platform
10+
:::
11+
:::{grid-item-card} Before you begin it is recommended you go through the following:
12+
:class-card: card-prerequisites
13+
* [Setting up ExecuTorch](./getting-started-setup.md)
14+
* [Model Lowering Tutorial](./tutorials/export-to-executorch-tutorial)
15+
* [Custom Quantization](./quantization-custom-quantization.md)
16+
* [ExecuTorch XNNPACK Delegate](./native-delegates-XNNPACK-Delegate.md)
17+
:::
18+
::::
19+
20+
21+
## Lowering a model to XNNPACK
22+
```python
23+
import torch
24+
import torchvision.models as models
25+
26+
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
27+
from executorch.examples.portable.utils import export_to_edge
28+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
29+
30+
31+
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
32+
sample_inputs = (torch.randn(1, 3, 224, 224), )
33+
34+
edge = export_to_edge(mobilenet_v2, example_inputs)
35+
36+
37+
edge.to_backend(XnnpackPartitioner)
38+
```
39+
40+
We will go through this example with the MobileNetV2 pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate.
41+
42+
```python
43+
>>> print(edge.exported_program.graph_module)
44+
class GraphModule(torch.nn.Module):
45+
def forward(self, arg314_1: f32[1, 3, 224, 224]):
46+
lowered_module_0 = self.lowered_module_0
47+
executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg314_1)
48+
getitem: f32[1, 1280, 1, 1] = executorch_call_delegate[0]
49+
50+
aten_view_copy_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280])
51+
52+
aten_clone_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default)
53+
54+
lowered_module_1 = self.lowered_module_1
55+
executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_clone_default)
56+
getitem_1: f32[1, 1000] = executorch_call_delegate_1[0]
57+
return (getitem_1,)
58+
```
59+
60+
We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs which are being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were able to be delegated to XNNPACK. We can also see the operators which were not able to be lowered to the XNNPACK delegate, such as `clone` and `view_copy`.
61+
62+
```python
63+
from executorch.examples.portable.utils import save_pte_program
64+
65+
exec_prog = edge.to_executorch()
66+
save_pte_program(exec_prog.buffer, "xnnpack_mobilenetv2.pte")
67+
```
68+
After lowering to the XNNPACK Program, we can then prepare it for executorch and save the model as a `.pte` file. `.pte` is a binary format that stores the serialized ExecuTorch graph.
69+
70+
71+
## Lowering a Quantized Model to XNNPACK
72+
The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize()` python helper function conveniently added to the `executorch/executorch/examples` folder.
73+
74+
```python
75+
import torch._export as export
76+
from executorch.examples.xnnpack.quantization.utils import quantize
77+
from executorch.exir import EdgeCompileConfig
78+
79+
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
80+
sample_inputs = (torch.randn(1, 3, 224, 224), )
81+
82+
mobilenet_v2 = export.capture_pre_autograd_graph(mobilenet_v2, sample_inputs) # 2-stage export for quantization path
83+
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
84+
```
85+
86+
Quantization requires a two stage export. First we use the `capture_pre_autograd_graph` API to capture the model before giving it to `quantize` utility function. After performing the quantization step, we can now leverage the XNNPACK delegate to lower the quantized exported model graph. From here, the procedure is the same as for the non-quantized model lowering to XNNPACK.
87+
88+
```python
89+
# Continued from earlier...
90+
edge = export_to_edge(
91+
quantized_mobilenetv2,
92+
example_inputs,
93+
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False)
94+
)
95+
edge.to_backend(XnnpackPartitioner)
96+
97+
exec_prog = edge.to_executorch()
98+
save_pte_program(exec_prog.buffer, "qs8_xnnpack_mobilenetv2.pte")
99+
```
100+
101+
## Lowering with `aot_compiler.py` script
102+
We have also provided a script to quickly lower and export a few example models. You can run the script to generate lowered fp32 and quantized models. This script is used simply for convenience and performs all the same steps as those listed in the previous two sections.
103+
104+
```
105+
python3 -m examples.xnnpack.aot_compiler.py --model_name="mv2" --quantize --delegate
106+
```
107+
108+
Note in the example above,
109+
* the `-—model_name` specifies the model to use
110+
* the `-—quantize` flag controls whether the model should be quantized or not
111+
* the `-—delegate` flag controls whether we attempt to lower parts of the graph to the XNNPACK delegate.
112+
113+
The generated model file will be named `[model_name]_xnnpack_[qs8/fp32].pte` depending on the arguments supplied.
114+
115+
## Running the XNNPACK Model
116+
We will use `buck2` to run the `.pte` file with XNNPACK delegate instructions in it on your host platform. You can follow the instructions here to install [buck2](getting-started-setup.md). You can now run it with the prebuilt `xnn_executor_runner` provided in the examples. This will run the model on some sample inputs.
117+
118+
```bash
119+
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_fp32.pte
120+
# or to run the quantized variant
121+
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_qs8.pte
122+
```
123+
124+
## Building and Linking with the XNNPACK Backend
125+
You can build the XNNPACK backend [target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/targets.bzl#L54), and link it with your application binary such as an Android or iOS application. For more information on this you may take a look at this [resource](runtime-backend-delegate-implementation-and-linking.md) next.

0 commit comments

Comments
 (0)