# XNNPACK Delegate Lowering Tutorial

The following tutorial will familiarize you with leveraging the ExecuTorch XNNPACK Delegate to accelerate your ML models on CPU hardware. It goes over exporting and serializing a model to a binary file, targeting the XNNPACK Delegate backend, and running the model on a supported target platform. To get started quickly, use the script in the ExecuTorch repository with instructions on exporting and generating a binary file for a few sample models demonstrating the flow.

<!----This will show a grid card on the page----->
::::{grid} 2
:::{grid-item-card} What you will learn in this tutorial:
:class-card: card-learn
In this tutorial, you will learn how to export an XNNPACK-lowered model and run it on a target platform.
:::
:::{grid-item-card} Before you begin, it is recommended you go through the following:
:class-card: card-prerequisites
* [Installing Buck2](./getting-started-setup.md)
* [Setting up ExecuTorch](./examples-end-to-end-to-lower-model-to-delegate.md)
* [Model Lowering Tutorial](./runtime-backend-delegate-implementation-and-linking.md)
* [Custom Quantization](./quantization-custom-quantization.md)
* [ExecuTorch XNNPACK Delegate](./native-delegates-XNNPACK-Delegate.md)
:::
::::


## Lowering a Model to XNNPACK
```python
import torch
import torchvision.models as models

from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.examples.portable.utils import export_to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner


# Load a pretrained MobileNetV2 from TorchVision and put it in eval mode
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

# Export the model and convert it to the Edge dialect
edge = export_to_edge(mobilenet_v2, sample_inputs)

# Partition the graph and lower the XNNPACK-compatible subgraphs to the delegate
edge.exported_program = edge.to_backend(XnnpackPartitioner)
```

We will go through this example with the MobileNetV2 pretrained model downloaded from the [TorchVision library](reference). The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` API with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for the XNNPACK backend delegate to consume. Each identified subgraph is then serialized with the XNNPACK Delegate flatbuffer schema and replaced with a call (one per subgraph) to the XNNPACK Delegate.

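To see the result of partitioning, you can print the lowered graph. This is a minimal sketch, assuming the lowered program exposes its `graph_module` for inspection (the exact attribute layout may vary across ExecuTorch versions):

```python
# Print the lowered graph to inspect which subgraphs became delegate calls
# (a sketch -- the attribute layout may differ between ExecuTorch versions).
print(edge.exported_program.graph_module)
```
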
```python
class GraphModule(torch.nn.Module):
    def forward(self, arg314_1: f32[1, 3, 224, 224]):
        lowered_module_0 = self.lowered_module_0
        executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg314_1)
        getitem: f32[1, 1280, 1, 1] = executorch_call_delegate[0]

        aten_view_copy_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280])

        aten_clone_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default)

        lowered_module_1 = self.lowered_module_1
        executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_clone_default)
        getitem_1: f32[1, 1000] = executorch_call_delegate_1[0]
        return (getitem_1,)
```

We print the graph after lowering above to show the new nodes that call the XNNPACK Delegate, with the subgraphs that were delegated to XNNPACK passed as the first argument. The majority of `convolution-relu-add` blocks and `linear` blocks were delegated to XNNPACK. We can also see the operators which could not be lowered to the XNNPACK delegate, such as `clone` and `view_copy`.

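You can also check the delegation programmatically, for example by counting the delegate call nodes in the lowered graph. This is a rough sketch; the attribute and node names used here are assumptions that may differ across versions:

```python
# Count the calls to the XNNPACK delegate in the lowered graph
# (a sketch -- node/attribute naming is assumed and may vary by version).
graph_module = edge.exported_program.graph_module
delegate_calls = [
    node
    for node in graph_module.graph.nodes
    if node.op == "call_function" and "call_delegate" in str(node.target)
]
print(f"Delegate calls: {len(delegate_calls)}")
```
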
```python
from executorch.examples.portable.utils import save_pte_program

# Convert the lowered edge program to an ExecuTorch program and serialize it
exec_prog = edge.to_executorch()
save_pte_program(exec_prog.buffer, "xnnpack_mobilenetv2.pte")
```
After lowering to the XNNPACK Program, we can then prepare it for ExecuTorch and save the model as a `.pte` file. `.pte` is the binary format consumed by the ExecuTorch runtime.

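Conceptually, `save_pte_program` writes the serialized program buffer to disk. A minimal equivalent sketch, assuming `exec_prog.buffer` holds the serialized bytes:

```python
# Minimal equivalent of save_pte_program (a sketch): write the serialized
# program bytes straight to a .pte file on disk.
with open("xnnpack_mobilenetv2.pte", "wb") as f:
    f.write(exec_prog.buffer)
```

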
## Lowering a Quantized Model to XNNPACK
The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, see the [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize` Python helper function conveniently added to the `executorch/executorch/examples` folder.

```python
import torch._export as export
from executorch.examples.xnnpack.quantization.utils import quantize
from executorch.exir import EdgeCompileConfig

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

# Capture the model before autograd, then quantize it with the example helper
mobilenet_v2 = export.capture_pre_autograd_graph(mobilenet_v2, sample_inputs)
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
```

Quantization requires a two-stage export. First we use the `capture_pre_autograd_graph` API to capture the model before handing it to the `quantize` utility function. After performing the quantization step, we can leverage the XNNPACK delegate to lower the quantized, exported model graph, following the same steps as for the non-quantized model.

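For context, the `quantize` helper in the examples folder roughly wraps the PT2E quantization flow. The sketch below shows the general shape of that flow using the `XNNPACKQuantizer` with a symmetric config and a single calibration pass; these details are assumptions, and the actual helper may configure things differently:

```python
# Rough sketch of the PT2E flow that the quantize() helper wraps
# (assumptions: XNNPACKQuantizer with a symmetric config, one calibration pass).
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

prepared = prepare_pt2e(mobilenet_v2, quantizer)  # insert observers
prepared(*sample_inputs)                          # calibrate on sample inputs
quantized_mobilenetv2 = convert_pt2e(prepared)    # produce the quantized graph
```
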
```python
# Continued from earlier...
edge = export_to_edge(
    quantized_mobilenetv2,
    sample_inputs,
    edge_compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
edge.exported_program = edge.to_backend(XnnpackPartitioner)

exec_prog = edge.to_executorch()
save_pte_program(exec_prog.buffer, "qs8_xnnpack_mobilenetv2.pte")
```

## Lowering with the aot_compiler.py Script
We have also provided a script to quickly lower and export a few example models. You can run the script to generate lowered fp32 and quantized models. This script is provided simply for convenience and performs all the same steps as those listed in the previous two sections.

```bash
python3 -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
```

Note that in the example above, `--model_name` specifies the model to use, the `--quantize` flag controls whether the model should be quantized, and the `--delegate` flag controls whether we attempt to lower parts of the graph to the XNNPACK delegate. The generated model file will be named `[model_name]_xnnpack_[qs8/fp32].pte` depending on the arguments supplied.

## Running the XNNPACK Model
We will use `buck2` to run the `.pte` file, now containing the XNNPACK delegate instructions, on your host platform. You can follow the instructions here to install [buck2](getting-started-setup.md). You can then run the model with the prebuilt `xnn_executor_runner` provided in the examples. This will run the model on some sample inputs.

```bash
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_fp32.pte
# or to run the quantized variant
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_qs8.pte
```

## Building and Linking with the XNNPACK Backend
You can build the XNNPACK backend [target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/targets.bzl#L54) and link it with your application binary, such as an Android or iOS application. For more information, take a look at this [resource](runtime-backend-delegate-implementation-and-linking.md) next.