
Commit 206af28

kirklandsign authored and mergennachin committed
Fix docs (#882)
Summary:
Pull Request resolved: #882

Address issues on T166442606
1. Make sure all Python code snippets and bash commands can run.
2. Modify some links a bit.

Reviewed By: mcr229, shoumikhin

Differential Revision: D50250157

fbshipit-source-id: 67cf996df0967d89f2f87c8ab6a2604bc829f546
1 parent 79c9e23 commit 206af28

File tree

1 file changed: +57 -35 lines

docs/source/tutorial-xnnpack-delegate-lowering.md

Lines changed: 57 additions & 35 deletions
@@ -23,47 +23,48 @@ In this tutorial, you will learn how to export an XNNPACK lowered Model and run
import torch
import torchvision.models as models

+from torch.export import export
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
-from executorch.examples.portable.utils import export_to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge


-mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
+mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

-edge = export_to_edge(mobilenet_v2, example_inputs)
-
+edge = to_edge(export(mobilenet_v2, sample_inputs))

edge = edge.to_backend(XnnpackPartitioner)
```

-We will go through this example with the MobileNetV2 pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate.
+We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate.

```python
->>> print(edge.exported_program.graph_module)
-class GraphModule(torch.nn.Module):
-    def forward(self, arg314_1: f32[1, 3, 224, 224]):
-        lowered_module_0 = self.lowered_module_0
-        executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg314_1)
-        getitem: f32[1, 1280, 1, 1] = executorch_call_delegate[0]
-
-        aten_view_copy_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280])
-
-        aten_clone_default: f32[1, 1280] = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default)
+>>> print(edge.exported_program().graph_module)
+GraphModule(
+  (lowered_module_0): LoweredBackendModule()
+  (lowered_module_1): LoweredBackendModule()
+)

-        lowered_module_1 = self.lowered_module_1
-        executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_clone_default)
-        getitem_1: f32[1, 1000] = executorch_call_delegate_1[0]
-        return (getitem_1,)
+def forward(self, arg314_1):
+    lowered_module_0 = self.lowered_module_0
+    executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg314_1); lowered_module_0 = arg314_1 = None
+    getitem = executorch_call_delegate[0]; executorch_call_delegate = None
+    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]); getitem = None
+    aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None
+    lowered_module_1 = self.lowered_module_1
+    executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_clone_default); lowered_module_1 = aten_clone_default = None
+    getitem_1 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None
+    return (getitem_1,)
```

We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs which are being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were able to be delegated to XNNPACK. We can also see the operators which were not able to be lowered to the XNNPACK delegate, such as `clone` and `view_copy`.

```python
-from executorch.examples.portable.utils import save_pte_program
-
exec_prog = edge.to_executorch()
-save_pte_program(exec_prog.buffer, "xnnpack_mobilenetv2.pte")
+
+with open("xnnpack_mobilenetv2.pte", "wb") as file:
+    file.write(exec_prog.buffer)
```
After lowering to the XNNPACK Program, we can then prepare it for executorch and save the model as a `.pte` file. `.pte` is a binary format that stores the serialized ExecuTorch graph.

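Taken together, the lines added in this hunk form the following end-to-end fp32 lowering script (a consolidated sketch using only the calls shown above; the output file name matches the updated tutorial):

```python
import torch
import torchvision.models as models

from torch.export import export
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

# Load the pretrained MobileNetV2 and build a sample input, as in the updated tutorial.
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Export to ATen dialect, convert to Edge dialect, then delegate suitable subgraphs to XNNPACK.
edge = to_edge(export(mobilenet_v2, sample_inputs))
edge = edge.to_backend(XnnpackPartitioner)

# Serialize to an ExecuTorch program and write it out as a .pte file.
exec_prog = edge.to_executorch()
with open("xnnpack_mobilenetv2.pte", "wb") as file:
    file.write(exec_prog.buffer)
```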
@@ -72,37 +73,58 @@ After lowering to the XNNPACK Program, we can then prepare it for executorch and
The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize()` python helper function conveniently added to the `executorch/executorch/examples` folder.

```python
-import torch._export as export
-from executorch.examples.xnnpack.quantization.utils import quantize
+from torch._export import capture_pre_autograd_graph
from executorch.exir import EdgeCompileConfig

-mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights).eval()
+mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

-mobilenet_v2 = export.capture_pre_autograd_graph(mobilenet_v2, sample_inputs) # 2-stage export for quantization path
+mobilenet_v2 = capture_pre_autograd_graph(mobilenet_v2, sample_inputs) # 2-stage export for quantization path
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+
+def quantize(model, example_inputs):
+    """This is the official recommended flow for quantization in pytorch 2.0 export"""
+    print(f"Original model: {model}")
+    quantizer = XNNPACKQuantizer()
+    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
+    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    quantizer.set_global(operator_config)
+    m = prepare_pt2e(model, quantizer)
+    # calibration
+    m(*example_inputs)
+    m = convert_pt2e(m)
+    print(f"Quantized model: {m}")
+    # make sure we can export to flat buffer
+    return m
+
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
```

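The `quantize()` helper added above calibrates on a single sample batch via `m(*example_inputs)`. If more representative calibration data is available, that step can loop over several batches instead; a minimal sketch reusing the same imports, where `calibration_batches` is a hypothetical iterable of input tuples, not something defined in the commit:

```python
def quantize_with_calibration(model, calibration_batches):
    """Variant of the quantize() helper above that calibrates over several batches.

    calibration_batches is assumed to be an iterable of input tuples, e.g.
    [(torch.randn(1, 3, 224, 224),), ...]; it is illustrative only.
    """
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
    m = prepare_pt2e(model, quantizer)
    for batch in calibration_batches:
        m(*batch)  # run the observers over each calibration batch
    return convert_pt2e(m)
```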
Quantization requires a two stage export. First we use the `capture_pre_autograd_graph` API to capture the model before giving it to `quantize` utility function. After performing the quantization step, we can now leverage the XNNPACK delegate to lower the quantized exported model graph. From here, the procedure is the same as for the non-quantized model lowering to XNNPACK.

```python
# Continued from earlier...
-edge = export_to_edge(
-    quantized_mobilenetv2,
-    example_inputs,
-    edge_compile_config=EdgeCompileConfig(_check_ir_validity=False)
-)
+edge = to_edge(export(quantized_mobilenetv2, sample_inputs), compile_config=EdgeCompileConfig(_check_ir_validity=False))
+
edge = edge.to_backend(XnnpackPartitioner)

exec_prog = edge.to_executorch()
-save_pte_program(exec_prog.buffer, "qs8_xnnpack_mobilenetv2.pte")
+
+with open("qs8_xnnpack_mobilenetv2.pte", "wb") as file:
+    file.write(exec_prog.buffer)
```

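Once both the fp32 and the quantized programs have been written out, a quick way to see the effect of quantization is to compare the sizes of the two `.pte` files; a small illustrative check using only the standard library (actual sizes depend on the torchvision and executorch versions in use):

```python
import os

for path in ("xnnpack_mobilenetv2.pte", "qs8_xnnpack_mobilenetv2.pte"):
    # Report each serialized ExecuTorch program's size in megabytes.
    print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")
```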
## Lowering with `aot_compiler.py` script
We have also provided a script to quickly lower and export a few example models. You can run the script to generate lowered fp32 and quantized models. This script is used simply for convenience and performs all the same steps as those listed in the previous two sections.

```
-python3 -m examples.xnnpack.aot_compiler.py --model_name="mv2" --quantize --delegate
+python3 -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
```

Note in the example above,
@@ -113,7 +135,7 @@ Note in the example above,
The generated model file will be named `[model_name]_xnnpack_[qs8/fp32].pte` depending on the arguments supplied.

## Running the XNNPACK Model
-We will use `buck2` to run the `.pte` file with XNNPACK delegate instructions in it on your host platform. You can follow the instructions here to install [buck2](getting-started-setup.md). You can now run it with the prebuilt `xnn_executor_runner` provided in the examples. This will run the model on some sample inputs.
+We will use `buck2` to run the `.pte` file with XNNPACK delegate instructions in it on your host platform. You can follow the instructions here to install [buck2](getting-started-setup.md#building-a-runtime). You can now run it with the prebuilt `xnn_executor_runner` provided in the examples. This will run the model on some sample inputs.

```bash
buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_fp32.pte
@@ -122,4 +144,4 @@ buck2 run examples/backend:xnn_executor_runner -- --model_path ./mv2_xnnpack_qs8
```

## Building and Linking with the XNNPACK Backend
-You can build the XNNPACK backend [target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/targets.bzl#L54), and link it with your application binary such as an Android or iOS application. For more information on this you may take a look at this [resource](demo-apps-android.md) next.
+You can build the XNNPACK backend [BUCK target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/targets.bzl#L54) and [CMake target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/CMakeLists.txt#L83), and link it with your application binary such as an Android or iOS application. For more information on this you may take a look at this [resource](demo-apps-android.md) next.
