
Commit c5f2f45

mcr229 authored and facebook-github-bot committed
XNNPACK Lowering Tutorial (#634)
Summary: Pull Request resolved: #634 Differential Revision: D49945479 fbshipit-source-id: 68d6b557e0c1d6a94ca3e329129a03a27e72ddce
1 parent 80e9e38 commit c5f2f45

File tree

3 files changed: +121 -1 lines changed


docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ Topics in this section will help you get started with ExecuTorch.
    tutorials/export-to-executorch-tutorial
    build-run-xtensa
    tutorials/sdk-integration-tutorial
+   tutorial-xnnpack-delegate-lowering

 Tutorials and Examples
 ~~~~~~~~~~~~~~~~~~~~~~

docs/source/native-delegates-executorch-xnnpack-delegate.md

Lines changed: 1 addition & 1 deletion
@@ -129,5 +129,5 @@ print(quantized_model)
 You will now see the Q/DQ representation of the model, which means `torch.ops.quantized_decomposed.dequantize_per_tensor` are inserted at quantized operator inputs and `torch.ops.quantized_decomposed.quantize_per_tensor` are inserted at operator outputs. [Example](https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/representation/rewrite.py#L40):

 ## See Also
-- Lowering to XNNPACK Tutorial (TBD)
 - [Integrating XNNPACK Delegate Android App](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/ExecuTorchDemo/README.md)
+- [Complete the Lowering to XNNPACK Tutorial](tutorial-xnnpack-delegate-lowering.md)
docs/source/tutorial-xnnpack-delegate-lowering.md

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
# XNNPACK Delegate Lowering Tutorial

The following tutorial will familiarize you with leveraging the ExecuTorch XNNPACK Delegate for accelerating your ML models on CPU hardware. It covers exporting and serializing a model to a binary file, targeting the XNNPACK delegate backend, and running the model on a supported target platform. To get started quickly, use the script in the ExecuTorch repository, which comes with instructions for exporting and generating binaries for a few sample models that demonstrate the flow.

<!----This will show a grid card on the page----->
::::{grid} 2
:::{grid-item-card} What you will learn in this tutorial:
:class-card: card-learn
In this tutorial, you will learn how to export an XNNPACK-lowered model and run it on a target platform.
:::
:::{grid-item-card} Before you begin, it is recommended you go through the following:
:class-card: card-prerequisites
* [Installing Buck2](./getting-started-setup.md)
* [Setting up ExecuTorch](./examples-end-to-end-to-lower-model-to-delegate.md)
* [Model Lowering Tutorial](./runtime-backend-delegate-implementation-and-linking.md)
* [Custom Quantization](./quantization-custom-quantization.md) (optional)
* [ExecuTorch XNNPACK Delegate](./native-delegates-executorch-xnnpack-delegate.md) (optional)
:::
::::

## Lowering a model to XNNPACK
```python
import torch
import torchvision.models as models

from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.examples.portable.utils import export_to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir.backend.backend_api import to_backend


mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

edge = export_to_edge(mobilenet_v2, sample_inputs)

# Partition the edge program and lower the XNNPACK-compatible subgraphs to the delegate.
edge.exported_program = to_backend(edge.exported_program, XnnpackPartitioner)
```

We will go through this example with the MobileNetV2 model. The lowering flow starts after exporting the model `to_edge`. We then call the `to_backend` API with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK lowering; these subgraphs are then serialized with the XNNPACK Delegate flatbuffer schema, and each one is replaced with a call to the XNNPACK Delegate.
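
If you want to inspect the result of partitioning yourself, printing the lowered program is usually enough. This is a minimal sketch, assuming `edge` is the program produced by `export_to_edge` and `to_backend` above; the exact printing helper may differ by ExecuTorch version:

```python
# Print the lowered graph module; delegated subgraphs show up as
# calls to executorch_call_delegate on lowered_module_* attributes.
print(edge.exported_program.graph_module)
```

The dump below shows what this looks like for MobileNetV2.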

```
class GraphModule(torch.nn.Module):
    def forward(self, arg314_1: f32[1, 3, 224, 224]):
        lowered_module_0 = self.lowered_module_0
        executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg314_1)
        getitem: f32[1, 1280, 1, 1] = executorch_call_delegate[0]

        aten_view_copy_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280])

        aten_clone_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default)

        lowered_module_1 = self.lowered_module_1
        executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_clone_default)
        getitem_1: f32[1, 1000] = executorch_call_delegate_1[0]
        return (getitem_1,)
```

We print the graph after lowering above to show the new calls to the XNNPACK Delegate and the subgraphs that were delegated to it. The majority of convolution-relu-add blocks and linear blocks were able to be delegated to XNNPACK. We can also see the operators that could not be lowered to the XNNPACK delegate, such as the `view_copy` and `clone` ops above.
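
As a quick sanity check, you can also count how many delegate calls ended up in the graph. This is an illustrative sketch over the standard `torch.fx` graph, assuming `edge.exported_program` is the lowered program from above; it is not an official ExecuTorch utility:

```python
# Count the executorch_call_delegate nodes, i.e. the subgraphs handed to XNNPACK.
delegate_calls = [
    node
    for node in edge.exported_program.graph_module.graph.nodes
    if node.op == "call_function" and "executorch_call_delegate" in str(node.target)
]
print(f"Delegated subgraphs: {len(delegate_calls)}")  # expect 2 for this MobileNetV2 example
```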

```python
from executorch.examples.export.utils import save_pte_program

exec_prog = edge.to_executorch()
save_pte_program(exec_prog.buffer, "xnnpack_mobilenetv2.pte")
```
After lowering to the XNNPACK program, we can then prepare it for ExecuTorch and save the model as a `.pte` file.
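
For reference, saving the program amounts to writing the serialized flatbuffer bytes to disk, so a minimal equivalent sketch (assuming `exec_prog` from above) would be:

```python
# Equivalent to save_pte_program: write the serialized program bytes to a .pte file.
with open("xnnpack_mobilenetv2.pte", "wb") as f:
    f.write(exec_prog.buffer)
```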

## Lowering a Quantized model to XNNPACK
The XNNPACK Delegate is also a backend for executing symmetrically quantized models. To understand the quantization flow and how to quantize models, see [Custom Quantization](quantization-custom-quantization.md); for the sake of this tutorial we will leverage the `quantize()` function conveniently added to the `executorch/examples` folder.

```python
import torch._export as export
from executorch.examples.xnnpack.quantization.utils import quantize
from executorch.exir import EdgeCompileConfig

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

mobilenet_v2 = export.capture_pre_autograd_graph(mobilenet_v2, sample_inputs)
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
```
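
If you want to confirm that quantization took effect, printing the captured module shows the Q/DQ representation, i.e. `torch.ops.quantized_decomposed.quantize_per_tensor` / `dequantize_per_tensor` calls inserted around the quantized operators (as described in the XNNPACK delegate doc changed in this same commit). A minimal check, assuming `quantized_mobilenetv2` from above:

```python
# Inspect the quantized graph: expect quantize_per_tensor / dequantize_per_tensor
# ops surrounding the operators that were quantized.
print(quantized_mobilenetv2)
```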

Quantization requires a two-stage export, so we first use the `capture_pre_autograd_graph` API to capture the model before handing it to `quantize`. After performing these quantization steps, we can leverage the XNNPACK Delegate to lower the quantized subgraphs and use XNNPACK as a backend for executing the quantized model. We can now follow the same lowering steps as before to lower this quantized model.

```python
edge = export_to_edge(
    quantized_mobilenetv2,
    sample_inputs,
    edge_compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
edge.exported_program = to_backend(edge.exported_program, XnnpackPartitioner)

exec_prog = edge.to_executorch()
save_pte_program(exec_prog.buffer, "qs8_xnnpack_mobilenetv2.pte")
```

## Lowering with aot_compiler.py script
We have also provided a script to quickly lower and export a few example models. You can run it to generate lowered fp32 and quantized models:

```
python3 -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
```

Note that in the above, `--model_name` specifies the model to use, the `--quantize` flag controls whether the model is quantized, and the `--delegate` flag controls whether we lower to the XNNPACK delegate. The generated model file will be named `[model_name]_xnnpack_[qs8/fp32].pte`.
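
For example, since the runner step below expects an fp32 model, generating the non-quantized variant should simply be a matter of dropping the `--quantize` flag (assuming the same module path as above):

```
python3 -m examples.xnnpack.aot_compiler --model_name="mv2" --delegate
# produces ./mv2_xnnpack_fp32.pte per the naming scheme described above
```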

## Running the XNNPACK model on your host platform
We will use buck2 to run the XNNPACK-lowered model on your host platform. You can follow the instructions here to install [buck2](getting-started-setup.md). Once you have your lowered model, you can run it with the prebuilt `xnn_executor_runner` provided in the examples; this simply runs the model on some sample inputs.

```
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_fp32.pte
```

## Building and Linking with the XNNPACK Backend
You can build the XNNPACK backend [target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/targets.bzl#L54) and link it with your binary. For more information on backend delegate linking, take a look at this [resource](runtime-backend-delegate-implementation-and-linking.md).
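
As a rough illustration of what linking looks like in a Buck-based build, your binary takes a dependency on the XNNPACK backend target. The rule and target path below are assumptions for this sketch, not a verified build file; check `backends/xnnpack/targets.bzl` in the repository for the real target name:

```
# BUCK (sketch): link the XNNPACK backend into a hypothetical runner binary.
cxx_binary(
    name = "my_xnnpack_runner",        # hypothetical binary
    srcs = ["my_xnnpack_runner.cpp"],  # hypothetical source
    deps = [
        # Assumed target name; see backends/xnnpack/targets.bzl.
        "//backends/xnnpack:xnnpack_backend",
    ],
)
```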

After building the XNNPACK backend and exporting your XNNPACK-delegated model, you can also integrate the model and backend into your app to leverage CPU acceleration on your device. You can follow this guide on how to integrate XNNPACK into your app (TBD).
