Commit f618121

kimishpatel authored and facebook-github-bot committed

Update export documentation (#296)

Summary: Pull Request resolved: #296 In light of the two-stage APIs, the export documentation needs to change. The update also links to the GitHub issue, establishing context and the need for the change, along with highlighting the long-term plan for the export API.

Reviewed By: mergennachin
Differential Revision: D49209531
fbshipit-source-id: 6c9dd56609f663f57090ecbba5bb8e2b5de94411

1 parent b2ef921 commit f618121

2 files changed: +25, −16 lines (one binary file change of −346 KB)
docs/website/docs/tutorials/exporting_to_executorch.md

Lines changed: 25 additions & 16 deletions
@@ -32,9 +32,10 @@ embedded devices. At a high level, the AOT steps are the following:

### 1.1 Exporting to EXIR ATen Dialect

-The entrypoint to Executorch is through the `exir.capture` API. This function
-utilizes [torch.export](https://pytorch.org/docs/main/export.html) to
-fully capture a PyTorch Model (either `torch.nn.Module` or a callable) into a
+NB: Export APIs are undergoing changes to better align with the long-term state of export. Please refer to https://github.com/pytorch/executorch/issues/290 for more details.
+
+The entrypoint to Executorch is through the `torch._export.capture_pre_autograd_graph` API, which is used
+to fully capture a PyTorch Model (either `torch.nn.Module` or a callable) into a
`torch.fx` graph representation.

In order for the model to
@@ -48,23 +49,26 @@ through registering the custom operator to a torch library and providing a meta
kernel. To learn more about exporting a model or if you have trouble exporting,
you can look at [these docs](../export/00_export_manual.md)

-To enable exporting input shape-dependent models, the `exir.capture` API also
+To enable exporting input shape-dependent models, this API also
takes in a list of constraints, where users can specify which input shapes are
dynamic and impose ranges on them. To learn more about constraints, you can look
at [these docs](../export/constraint_apis.md)

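As an aside, here is a rough sketch of what such a constraints list might look like, assuming the `dynamic_dim` API described in the constraint docs (the exact import path and bounds are illustrative and may differ across versions):

```python
import torch
from torch._export import dynamic_dim  # assumed import path; see the constraint docs

example_input = torch.randn(3, 4)

# Mark dim 0 of the input as dynamic and bound it to the range [2, 16].
constraints = [
    dynamic_dim(example_input, 0) >= 2,
    dynamic_dim(example_input, 0) <= 16,
]
```

This `constraints` list is what the snippets below pass to both `capture_pre_autograd_graph` and `exir.capture`.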
-The output of `exir.capture` is a fully flattened graph (meaning the graph does
-not contain any module hierarchy, except in the case of control flow operators)
-containing the
-[Core ATen Operators](https://pytorch.org/docs/main/ir.html), functional
-variants of custom
-operators (they do not do any mutations or aliasing), and control flow
-operators. The detailed specification for the result of `exir.capture` can be
+The output of `torch._export.capture_pre_autograd_graph` is a fully flattened graph (meaning the graph does
+not contain any module hierarchy, except in the case of control flow operators).
+Furthermore, the captured graph is in ATen dialect with an ATen opset that is autograd-safe, i.e. safe for eager-mode training.
+This is important for quantization, as noted in https://github.com/pytorch/executorch/issues/290.
+
+The ATen operator set of the graph obtained via `torch._export.capture_pre_autograd_graph` is the full set of ATen ops (~3000). This
+operator set is further refined into the [Core ATen Operators](https://pytorch.org/docs/master/ir.html) via `exir.capture`.
+The resulting IR is functional, i.e. it does not contain any mutations or aliasing on operator outputs.
+The detailed specification for the result of `exir.capture` can be
found in the [EXIR Reference](../ir_spec/00_exir.md) and is specifically in the
[EXIR ATen Dialect](../ir_spec/01_aten_dialect.md).

```python
import torch
+from torch import _export as export
import executorch.exir as exir

class MyModule(torch.nn.Module):
@@ -76,7 +80,9 @@ class MyModule(torch.nn.Module):
    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

-aten_dialect = exir.capture(MyModule(), (torch.randn(3, 4),), constraints)
+pre_autograd_aten_dialect = export.capture_pre_autograd_graph(MyModule(), (torch.randn(3, 4),), constraints)
+# Quantization APIs can optionally be invoked on top of pre_autograd_aten_dialect
+aten_dialect = exir.capture(pre_autograd_aten_dialect, (torch.randn(3, 4),), constraints)

print(aten_dialect.exported_program)
"""
@@ -95,7 +101,7 @@ At this point, users can choose to run additional passes through the
`exported_program._transform(passes)` function. A tutorial on how to write
transformations can be found [here](./passes.md).

-Additionally, users can run quantization at this step. A tutorial for doing so can be found [here](./quantization_flow.md).
+For quantization API usage, a tutorial can be found [here](./quantization_flow.md).
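As a rough sketch of how quantization might be invoked on the pre-autograd graph, assuming the PT2E quantization APIs (`prepare_pt2e`/`convert_pt2e`) and the XNNPACK quantizer covered in the quantization tutorial (the quantizer choice and exact imports are illustrative):

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Configure an example quantizer with a symmetric quantization config.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# Insert observers, run representative data through the model, then convert.
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
prepared_graph(torch.randn(3, 4))  # calibration run with a representative input
quantized_graph = convert_pt2e(prepared_graph)

# quantized_graph can then be passed to exir.capture as before.
```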

### 1.2 Lower to EXIR Edge Dialect

@@ -111,7 +117,8 @@ documentation on the Edge Dialect can be found
This lowering will be done through the `to_edge()` API.

```python
-aten_dialect = exir.capture(MyModule(), (torch.randn(3, 4),))
+pre_autograd_aten_dialect = export.capture_pre_autograd_graph(MyModule(), (torch.randn(3, 4),), constraints)
+aten_dialect = exir.capture(pre_autograd_aten_dialect, (torch.randn(3, 4),))
edge_dialect = aten_dialect.to_edge()

print(edge_dialect.exported_program)
@@ -162,7 +169,8 @@ planning pass can also be passed into `to_executorch`. A tutorial on how to
write a memory planning pass is here (TODO).

```python
-aten_dialect = exir.capture(MyModule(), (torch.randn(3, 4),))
+pre_autograd_aten_dialect = export.capture_pre_autograd_graph(MyModule(), (torch.randn(3, 4),), constraints)
+aten_dialect = exir.capture(pre_autograd_aten_dialect, (torch.randn(3, 4),))
edge_dialect = aten_dialect.to_edge()
# edge_dialect = to_backend(edge_dialect.exported_program, CustomBackendPartitioner)
executorch_program = edge_dialect.to_executorch(executorch_backend_config)
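The `executorch_backend_config` used above is assumed to be constructed elsewhere in the tutorial; as a rough sketch, assuming `ExecutorchBackendConfig` and `MemoryPlanningPass` from `executorch.exir` (exact arguments may differ):

```python
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

# Example config selecting the greedy memory planning algorithm.
executorch_backend_config = ExecutorchBackendConfig(
    memory_planning_pass=MemoryPlanningPass("greedy"),
)
```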
@@ -190,7 +198,8 @@ Finally, the exported and delegated graph can be saved to a flatbuffer file to
be loaded in the Executorch runtime.

```python
-edge_dialect = exir.capture(MyModule(), (torch.randn(3, 4),)).to_edge()
+pre_autograd_aten_dialect = export.capture_pre_autograd_graph(MyModule(), (torch.randn(3, 4),), constraints)
+edge_dialect = exir.capture(pre_autograd_aten_dialect, (torch.randn(3, 4),)).to_edge()
# edge_dialect = to_backend(edge_dialect.exported_program, CustomBackendPartitioner)
executorch_program = edge_dialect.to_executorch(executorch_backend_config)
buffer = executorch_program.buffer

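As a minimal sketch of persisting that buffer (the `model.pte` filename is illustrative; any path works):

```python
# Write the serialized flatbuffer program to disk for the runtime to load.
with open("model.pte", "wb") as f:
    f.write(buffer)
```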