Commit 5555623

Write new top-level export and lowering documentation
1 parent d60e12d commit 5555623

8 files changed: +186 -11 lines changed

docs/source/api-life-cycle.md

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-# ExecuTorch API Life Cycle and Deprecation Policy
+# API Life Cycle and Deprecation Policy

## API Life Cycle

docs/source/compiler-delegate-and-partitioner.md

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-# Backend and Delegate
+# Backends and Delegates

Audience: Vendors, Backend Delegate developers, who are interested in integrating their own compilers and hardware as part of ExecuTorch

docs/source/executorch-runtime-api-reference.rst

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-ExecuTorch Runtime API Reference
+Runtime API Reference
================================

The ExecuTorch C++ API provides an on-device execution framework for exported PyTorch models.

docs/source/export-to-executorch-api-reference.rst

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-Export to ExecuTorch API Reference
+Export API Reference
----------------------------------

For detailed information on how APIs evolve and the deprecation process, please refer to the `ExecuTorch API Life Cycle and Deprecation Policy <api-life-cycle.html>`__.

docs/source/getting-started.md

Lines changed: 2 additions & 2 deletions

@@ -59,7 +59,7 @@ with open("model.pte", "wb") as f:
    f.write(et_program.buffer)

-If the model requires varying input sizes, you will need to specify the varying dimensions and bounds as part of the `export` call. See [Exporting a Model for ExecuTorch](/TODO.md) for more information.
+If the model requires varying input sizes, you will need to specify the varying dimensions and bounds as part of the `export` call. See [Model Export and Lowering](using-executorch-export.md) for more information.

The hardware backend to target is controlled by the partitioner parameter to to\_edge\_transform\_and\_lower. In this example, the XnnpackPartitioner is used to target mobile CPUs. See the delegate-specific documentation for a full description of the partitioner and available options.

@@ -198,7 +198,7 @@ For more information on the C++ APIs, see [Running an ExecuTorch Model Using the

## Next Steps
ExecuTorch provides a high-degree of customizability to support diverse hardware targets. Depending on your use cases, consider exploring one or more of the following pages:

-- [Exporting a Model to ExecuTorch](/TODO.md) for advanced model conversion options.
+- [Export and Lowering](using-executorch-export.md) for advanced model conversion options.
- [Delegates](/TODO.md) for available backends and configuration options.
- [Using ExecuTorch on Android](/TODO.md) and [Using ExecuTorch on iOS](TODO.md) for mobile runtime integration.
- [Using ExecuTorch with C++](/TODO.md) for embedded and mobile native development.

docs/source/index.rst

Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@ Topics in this section will help you get started with ExecuTorch.
   .. grid-item-card:: :octicon:`file-code;1em`
      Getting started with ExecuTorch
      :img-top: _static/img/card-background.svg
-     :link: getting-started-setup.html
+     :link: getting-started.html
      :link-type: url

      A step-by-step tutorial on how to get started with

@@ -190,6 +190,7 @@ Topics in this section will help you get started with ExecuTorch.

   backend-delegates-integration
   backend-delegates-dependencies
+  compiler-delegate-and-partitioner
   debug-backend-delegate

.. toctree::

@@ -207,7 +208,6 @@ Topics in this section will help you get started with ExecuTorch.
   :caption: Compiler Entry Points
   :hidden:

-  compiler-delegate-and-partitioner
   compiler-backend-dialect
   compiler-custom-compiler-passes
   compiler-memory-planning

docs/source/runtime-python-api-reference.rst

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-ExecuTorch Runtime Python API Reference
+Runtime Python API Reference
----------------------------------

The Python ``executorch.runtime`` module wraps the C++ ExecuTorch runtime. It can load and execute serialized ``.pte`` program files: see the `Export to ExecuTorch Tutorial <tutorials/export-to-executorch-tutorial.html>`__ for how to convert a PyTorch ``nn.Module`` to an ExecuTorch ``.pte`` program file. Execution accepts and returns ``torch.Tensor`` values, making it a quick way to validate the correctness of the program.

Lines changed: 177 additions & 2 deletions
@@ -1,3 +1,178 @@
-# Model Export
-Placeholder for top-level export documentation

# Model Export and Lowering

This section describes the process of taking a PyTorch model and converting it to the runtime format used by ExecuTorch. This process is commonly known as "exporting", as it uses the PyTorch export functionality to convert a PyTorch model into a format suitable for on-device execution. It yields a .pte file which is optimized for on-device execution using a particular backend.

## Prerequisites

Exporting requires the ExecuTorch Python libraries to be installed, typically by running `pip install executorch`. See [Installation](getting-started.md#installation) for more information. This process assumes you have a PyTorch model, can instantiate it from Python, and can provide example input tensors to run the model.

## The Export and Lowering Process

The process to export and lower a model to the .pte format typically involves the following steps, sketched in code after the list:

1) Select a backend to target.
2) Prepare the PyTorch model, including inputs and shape specification.
3) Export the model using torch.export.export.
4) Optimize the model for the target backend using to_edge_transform_and_lower.
5) Create the .pte file by calling to_executorch and serializing the output.
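
The following is a compact sketch of these steps, assuming a hypothetical `MyModel` class; the full worked example appears in the [Export and Lowering](#export-and-lowering) section below.

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner  # 1) target backend
from executorch.exir import to_edge_transform_and_lower

model = MyModel().eval()                                # 2) prepare the model (MyModel is a placeholder)
example_inputs = (torch.randn(1, 3, 224, 224),)         #    and example inputs
exported = torch.export.export(model, example_inputs)   # 3) export with torch.export
edge = to_edge_transform_and_lower(                     # 4) optimize for the chosen backend
    exported, partitioner=[XnnpackPartitioner()]
)
with open("model.pte", "wb") as f:                      # 5) serialize the .pte file
    f.write(edge.to_executorch().buffer)
```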

<br/>

Quantization - the process of using reduced precision to reduce inference time and memory footprint - is also commonly done at this stage. See [Quantization Overview](quantization-overview.md) for more information.
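
As a rough sketch only, the PT2E post-training quantization flow with the XNNPACK quantizer looks roughly like the following, run before lowering. Import paths and the capture API differ across PyTorch and ExecuTorch versions, so treat the specifics here as assumptions and follow the Quantization Overview for the supported recipe.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Assumes a `model` and example `inputs` tuple like those defined in the
# Model Preparation section below. The capture API name varies by version.
captured = torch.export.export_for_training(model, inputs).module()

# Annotate the graph with 8-bit symmetric quantization for XNNPACK.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(captured, quantizer)

# Calibrate with representative inputs, then convert to a quantized graph.
prepared(*inputs)
quantized = convert_pt2e(prepared)

# The quantized module is then exported and lowered as described below.
exported_program = torch.export.export(quantized, inputs)
```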

## Hardware Backends

ExecuTorch backends provide hardware acceleration for a specific hardware target. In order to achieve maximum performance on target hardware, ExecuTorch optimizes the model for a specific backend during the export and lowering process. This means that the resulting .pte file is specialized for the specific hardware. In order to deploy to multiple backends, such as Core ML on iOS and Arm CPU on Android, it is common to generate a dedicated .pte file for each.

The choice of hardware backend is informed by the hardware that the model is intended to be deployed on. Each backend has specific hardware requirements and levels of model support. See the documentation for each hardware backend for more details.

As part of the .pte file creation process, ExecuTorch identifies portions of the model (partitions) that are supported by the given backend. These sections are processed by the backend ahead of time to support efficient execution. Portions of the model that are not supported by the delegate, if any, are executed using the portable fallback implementation on CPU. This allows for partial model acceleration when not all model operators are supported on the backend, but it may have negative performance implications. In addition, multiple partitioners can be specified in order of priority. This allows, for example, operators not supported on GPU to run on CPU via XNNPACK.
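
For illustration, the following is a minimal sketch of specifying partitioners in priority order so that operators the GPU backend cannot handle fall back to XNNPACK on CPU. The `VulkanPartitioner` import path is an assumption based on that backend's Python package layout; check the Vulkan backend documentation for the exact name.

```python
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner  # assumed import path
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Assumes `exported_program` was produced by torch.export.export, as shown below.
# Partitioners are applied in order: Vulkan (GPU) first, then XNNPACK (CPU);
# any remaining operators run on the portable CPU fallback.
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner(), XnnpackPartitioner()],
).to_executorch()
```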

### Available Backends

Commonly used hardware backends are listed below. For mobile, consider using XNNPACK for Android and XNNPACK or Core ML for iOS. To create a .pte file for a specific backend, pass the appropriate partitioner class to `to_edge_transform_and_lower`. See the appropriate backend documentation and the [Export and Lowering](#export-and-lowering) section below for more information.

- [XNNPACK (Mobile CPU)](native-delegates-executorch-xnnpack-delegate.md)
- [Core ML (iOS)](native-delegates-executorch-coreml-delegate.md)
- [Metal Performance Shaders (iOS GPU)](native-delegates-executorch-mps-delegate.md)
- [Vulkan (Android GPU)](native-delegates-executorch-vulkan-delegate.md)
- [Qualcomm NPU](native-delegates-executorch-qualcomm-delegate.md)
- [MediaTek NPU](native-delegates-executorch-mediatek-delegate.md)
- [Arm Ethos-U NPU](native-delegates-executorch-arm-ethos-u-delegate.md)
- [Cadence DSP](native-delegates-executorch-cadence-delegate.md)

## Model Preparation

The export process takes in a standard PyTorch model, typically a `torch.nn.Module`. This can be a custom model definition or a model from an existing source, such as TorchVision or Hugging Face. See [Getting Started with ExecuTorch](getting-started.md) for an example of lowering a TorchVision model.

Model export is done from Python, commonly through a Python script or an interactive notebook such as Jupyter or Colab. The example below shows the instantiation and inputs for a simple PyTorch model. The inputs are prepared as a tuple of torch.Tensors, and the model can run with these inputs.

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d([1,1])
        )
        self.linear = torch.nn.Linear(16, 10)

    def forward(self, x):
        y = self.seq(x)
        y = torch.flatten(y, 1)
        y = self.linear(y)
        return y

model = Model()
inputs = (torch.randn(1,1,16,16),)
outputs = model(*inputs)
print(f"Model output: {outputs}")
```

## Export and Lowering

To actually export and lower the model, call `export`, `to_edge_transform_and_lower`, and `to_executorch` in sequence. This yields an ExecuTorch program which can be serialized to a file. Putting it all together, lowering the example model above using the XNNPACK delegate for mobile CPU performance can be done as follows:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import Dim, export

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, 3),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d([1,1])
        )
        self.linear = torch.nn.Linear(16, 10)

    def forward(self, x):
        y = self.seq(x)
        y = torch.flatten(y, 1)
        y = self.linear(y)
        return y

model = Model()
inputs = (torch.randn(1,1,16,16),)
dynamic_shapes = {
    "x": {
        2: Dim("h", min=16, max=1024),
        3: Dim("w", min=16, max=1024),
    }
}

exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)
```

This yields a `model.pte` file which can be run on mobile devices.

### Supporting Varying Input Sizes (Dynamic Shapes)

The PyTorch export process uses the example inputs provided to trace through the model and reason about the size and type of tensors at each step. Unless told otherwise, export will assume a fixed input size equal to the example inputs and will use this information to optimize the model.

Many models require support for varying input sizes. To support this, export takes a `dynamic_shapes` parameter, which informs the compiler of which dimensions can vary and their bounds. This takes the form of a nested dictionary, where keys correspond to input names and values specify which dimensions are dynamic and the bounds for each.

In the example model, inputs are provided as 4-dimensional tensors following the standard convention of batch, channels, height, and width (NCHW). An input with the shape `[1, 1, 16, 16]` indicates 1 batch, 1 channel, and a height and width of 16.

Suppose your model supports images with sizes between 16x16 and 1024x1024. The shape bounds can be specified as follows:

```python
dynamic_shapes = {
    "x": {
        2: Dim("h", min=16, max=1024),
        3: Dim("w", min=16, max=1024),
    }
}

ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
```

In the above example, `"x"` corresponds to the parameter name in `Model.forward`. The 2 and 3 keys correspond to dimensions 2 and 3, which are height and width. As there are no specifications for the batch and channel dimensions, these values are fixed according to the example inputs.

ExecuTorch uses the shape bounds both to optimize the model and to plan memory for model execution. For this reason, it is advised to set the dimension upper bounds no higher than needed, as higher bounds increase memory consumption.

For more complex use cases, the dynamic shape specification allows for mathematical relationships between dimensions. For more information, see [Expressing Dynamism](https://pytorch.org/docs/stable/export.html#expressing-dynamism).
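
As a hypothetical illustration, a derived `Dim` expression can tie one input's dimension to another's. The two-input model and its example tensors below are placeholders, not part of the example above.

```python
from torch.export import Dim, export

h = Dim("h", min=16, max=512)
w = Dim("w", min=16, max=512)

# Hypothetical model taking two NCHW images, with forward(self, x, y), where
# the second image is always twice as wide as the first. Batch and channel
# dimensions stay fixed to the example inputs.
dynamic_shapes = {
    "x": {2: h, 3: w},
    "y": {2: h, 3: 2 * w},
}

ep = export(two_input_model, (example_x, example_y), dynamic_shapes=dynamic_shapes)
```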

## Testing the Model

Before integrating the runtime code, it is common to test the exported model from Python. This can be used to evaluate model accuracy and sanity-check behavior before moving to the target device. Note that not all hardware backends are available from Python, as they may require specialized hardware to function. See the specific backend documentation for more information on hardware requirements and the availability of simulators. The XNNPACK delegate used in this example is always available on host machines.

```python
import torch
from executorch.runtime import Runtime

runtime = Runtime.get()

input_tensor = torch.randn(1, 1, 32, 32)
program = runtime.load_program("model.pte")
method = program.load_method("forward")
outputs = method.execute([input_tensor])
```
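
As a quick numerical sanity check, the runtime output can be compared against eager PyTorch execution. The snippet below assumes the `model` instance from the export example above is still in scope; the tolerance is an arbitrary choice.

```python
# Compare the ExecuTorch runtime output against eager PyTorch execution.
eager_output = model(input_tensor)
print(torch.allclose(eager_output, outputs[0], atol=1e-4))
```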

For more information, see [Runtime API Reference](executorch-runtime-api-reference.md).

## Next Steps

The PyTorch and ExecuTorch export and lowering APIs provide a high level of customizability to meet the needs of diverse hardware and models. See [torch.export](https://pytorch.org/docs/main/export.html) and [Export API Reference](export-to-executorch-api-reference.md) for more information.

For advanced use cases, see the following:
- [Quantization Overview](quantization-overview.md) for information on quantizing models to reduce inference time and memory footprint.
- [Memory Planning](compiler-memory-planning.md) for information on controlling memory placement and planning.
- [Custom Compiler Passes](compiler-custom-compiler-passes.md) for information on writing custom compiler passes.
- [Export IR Specification](ir-exir.md) for information on the intermediate representation generated by export.
