Commit fdfaa40

angelayi authored and facebook-github-bot committed
Update tutorial (#3242)
Summary: Removed the use of capture_pre_autograd_graph in places where we are not quantizing, since we want to minimize the usage of this API for easier deprecation in the future.

Reviewed By: mergennachin

Differential Revision: D56475332
1 parent ee8c3a6 commit fdfaa40
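For context, the flow this commit standardizes on: wherever the tutorial is not quantizing, it now calls ``torch.export.export`` directly instead of the old two-stage capture-then-export sequence. A minimal sketch of the new single-call flow (the module definition and input shape below are illustrative stand-ins for the tutorial's ``SimpleConv``):

import torch
from torch.export import export, ExportedProgram


class SimpleConv(torch.nn.Module):
    # Illustrative stand-in for the tutorial's example module.
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


example_args = (torch.randn(1, 3, 256, 256),)
# One-step export straight to the ATen Dialect; no separate
# capture_pre_autograd_graph call is needed when not quantizing.
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
print(aten_dialect)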

File tree

1 file changed

+33 −73 lines changed

docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 33 additions & 73 deletions
@@ -44,15 +44,11 @@
 #
 # The first step of lowering to ExecuTorch is to export the given model (any
 # callable or ``torch.nn.Module``) to a graph representation. This is done via
-# the two-stage APIs, ``torch._export.capture_pre_autograd_graph``, and
-# ``torch.export``.
-#
-# Both APIs take in a model (any callable or ``torch.nn.Module``), a tuple of
+# ``torch.export``, which takes in a ``torch.nn.Module``, a tuple of
 # positional arguments, optionally a dictionary of keyword arguments (not shown
 # in the example), and a list of dynamic shapes (covered later).
 
 import torch
-from torch._export import capture_pre_autograd_graph
 from torch.export import export, ExportedProgram
 
 
@@ -70,40 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-print("Pre-Autograd ATen Dialect Graph")
-print(pre_autograd_aten_dialect)
-
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
-print("ATen Dialect Graph")
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 print(aten_dialect)
 
 ######################################################################
-# The output of ``torch._export.capture_pre_autograd_graph`` is a fully
-# flattened graph (meaning the graph does not contain any module hierarchy,
-# except in the case of control flow operators). Furthermore, the captured graph
-# contains only ATen operators (~3000 ops) which are Autograd safe, for example, safe
-# for eager mode training.
-#
-# The output of ``torch.export`` further compiles the graph to a lower and
-# cleaner representation. Specifically, it has the following:
-#
-# - The graph is purely functional, meaning it does not contain operations with
-#   side effects such as mutations or aliasing.
-# - The graph contains only a small defined
-#   `Core ATen IR <https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir>`__
-#   operator set (~180 ops), along with registered custom operators.
-# - The nodes in the graph contain metadata captured during tracing, such as a
-#   stacktrace from user's code.
+# The output of ``torch.export.export`` is a fully flattened graph (meaning the
+# graph does not contain any module hierarchy, except in the case of control
+# flow operators). Additionally, the graph is purely functional, meaning it does
+# not contain operations with side effects such as mutations or aliasing.
 #
 # More specifications about the result of ``torch.export`` can be found
-# `here <https://pytorch.org/docs/2.1/export.html>`__ .
+# `here <https://pytorch.org/docs/main/export.html>`__ .
 #
-# Since the result of ``torch.export`` is a graph containing the Core ATen
-# operators, we will call this the ``ATen Dialect``, and since
-# ``torch._export.capture_pre_autograd_graph`` returns a graph containing the
-# set of ATen operators which are Autograd safe, we will call it the
-# ``Pre-Autograd ATen Dialect``.
+# The graph returned by ``torch.export`` only contains functional ATen operators
+# (~2000 ops), which we will call the ``ATen Dialect``.
 
 ######################################################################
 # Expressing Dynamism
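As a quick sanity check on the hunk above: the returned ``ExportedProgram`` is not called directly; its ``.module()`` method returns a runnable ``torch.nn.Module`` wrapper. A sketch, continuing from the illustrative ``SimpleConv`` export:

# .module() wraps the exported graph in a callable nn.Module.
output = aten_dialect.module()(torch.randn(1, 3, 256, 256))
print(output.shape)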
@@ -124,10 +100,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y
 
 
-f = Basic()
 example_args = (torch.randn(3, 3), torch.randn(3, 3))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
-aten_dialect: ExportedProgram = export(f, example_args)
+aten_dialect: ExportedProgram = export(Basic(), example_args)
 
 # Works correctly
 print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
@@ -153,15 +127,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y
 
 
-f = Basic()
 example_args = (torch.randn(3, 3), torch.randn(3, 3))
 dim1_x = Dim("dim1_x", min=1, max=10)
 dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
-pre_autograd_aten_dialect = capture_pre_autograd_graph(
-    f, example_args, dynamic_shapes=dynamic_shapes
+aten_dialect: ExportedProgram = export(
+    Basic(), example_args, dynamic_shapes=dynamic_shapes
 )
-aten_dialect: ExportedProgram = export(f, example_args, dynamic_shapes=dynamic_shapes)
-print("ATen Dialect Graph")
 print(aten_dialect)
 
 ######################################################################
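To see the declared dynamism in action, here is a self-contained sketch exercising the exported program at different sizes of dimension 1. The out-of-range behavior (a RuntimeError) is an assumption based on how ``torch.export`` guards declared ranges:

import torch
from torch.export import Dim, export, ExportedProgram


class Basic(torch.nn.Module):
    # Mirrors the tutorial's toy module.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
dim1_x = Dim("dim1_x", min=1, max=10)
dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
aten_dialect: ExportedProgram = export(
    Basic(), example_args, dynamic_shapes=dynamic_shapes
)

# Any dim-1 size inside [1, 10] reuses the same exported graph...
print(aten_dialect.module()(torch.ones(3, 5), torch.ones(3, 5)))
# ...while a size outside the declared range is rejected at run time.
try:
    aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15))
except RuntimeError as e:
    print("out-of-range input rejected:", e)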
@@ -198,7 +169,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # As our goal is to capture the entire computational graph from a PyTorch
 # program, we might ultimately run into untraceable parts of programs. To
 # address these issues, the
-# `torch.export documentation <https://pytorch.org/docs/2.1/export.html#limitations-of-torch-export>`__,
+# `torch.export documentation <https://pytorch.org/docs/main/export.html#limitations-of-torch-export>`__,
 # or the
 # `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
 # would be the best place to look.
@@ -207,10 +178,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # Performing Quantization
 # -----------------------
 #
-# To quantize a model, we can do so between the call to
-# ``torch._export.capture_pre_autograd_graph`` and ``torch.export``, in the
-# ``Pre-Autograd ATen Dialect``. This is because quantization must operate at a
-# level which is safe for eager mode training.
+# To quantize a model, we first need to capture the graph with
+# ``torch._export.capture_pre_autograd_graph``, perform quantization, and then
+# call ``torch.export``. ``torch._export.capture_pre_autograd_graph`` returns a
+# graph which contains ATen operators which are Autograd safe, meaning they are
+# safe for eager-mode training, which is needed for quantization. We will call
+# the graph at this level the ``Pre-Autograd ATen Dialect`` graph.
 #
 # Compared to
 # `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
@@ -220,6 +193,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # will annotate the nodes in the graph with information needed to quantize the
 # model properly for a specific backend.
 
+from torch._export import capture_pre_autograd_graph
+
 example_args = (torch.randn(1, 3, 256, 256),)
 pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
 print("Pre-Autograd ATen Dialect Graph")
@@ -268,13 +243,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 from executorch.exir import EdgeProgramManager, to_edge
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-print("Pre-Autograd ATen Dialect Graph")
-print(pre_autograd_aten_dialect)
-
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
-print("ATen Dialect Graph")
-print(aten_dialect)
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 print("Edge Dialect Graph")
@@ -298,16 +267,10 @@ def forward(self, x):
 
 
 encode_args = (torch.randn(1, 10),)
-aten_encode: ExportedProgram = export(
-    capture_pre_autograd_graph(Encode(), encode_args),
-    encode_args,
-)
+aten_encode: ExportedProgram = export(Encode(), encode_args)
 
 decode_args = (torch.randn(1, 5),)
-aten_decode: ExportedProgram = export(
-    capture_pre_autograd_graph(Decode(), decode_args),
-    decode_args,
-)
+aten_decode: ExportedProgram = export(Decode(), decode_args)
 
 edge_program: EdgeProgramManager = to_edge(
     {"encode": aten_encode, "decode": aten_decode}
@@ -328,8 +291,7 @@ def forward(self, x):
 # rather than the ``torch.ops.aten`` namespace.
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 print("Edge Dialect Graph")
 print(edge_program.exported_program())
@@ -353,7 +315,9 @@ def call_operator(self, op, args, kwargs, meta):
 print(transformed_edge_program.exported_program())
 
 ######################################################################
-# Note: if you see an error like `torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical`,
+# Note: if you see an error like ``torch._export.verifier.SpecViolationError:
+# Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
+# Aten Canonical``,
 # please file an issue in https://github.com/pytorch/executorch/issues and we're happy to help!
 
 
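The body of the pass whose ``call_operator`` appears in the hunk header above is elided from this diff. A sketch of what such a pass looks like, assuming ExecuTorch's ``ExportPass`` base class and edge-dialect op handles (the relu-to-sigmoid rewrite follows the tutorial's example; the class name is otherwise illustrative):

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertReluToSigmoid(ExportPass):
    # Rewrites every edge-dialect relu node into a sigmoid node.
    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.relu.default:
            return super().call_operator(
                exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
            )
        return super().call_operator(op, args, kwargs, meta)


# Applied to the EdgeProgramManager built in the hunks above.
transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),))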
@@ -365,7 +329,7 @@ def call_operator(self, op, args, kwargs, meta):
 # backend through the ``to_backend`` API. An in-depth documentation on the
 # specifics of backend delegation, including how to delegate to a backend and
 # how to implement a backend, can be found
-# `here <../compiler-delegate-and-partitioner.html>`__
+# `here <../compiler-delegate-and-partitioner.html>`__.
 #
 # There are three ways for using this API:
 #
@@ -393,8 +357,7 @@ def forward(self, x):
 
 # Export and lower the module to Edge Dialect
 example_args = (torch.ones(1),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(LowerableModule(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 to_be_lowered_module = edge_program.exported_program()
 
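What happens to ``to_be_lowered_module`` next is outside this hunk. A sketch of the first of the three ``to_backend`` usages the tutorial describes, assuming the demo backend that ships with ExecuTorch (the backend id string, the registering import, and the empty compile-spec list are taken from that demo and may differ in the actual tutorial):

from executorch.exir.backend.backend_api import to_backend

# Importing the demo backend registers it under "BackendWithCompilerDemo".
from executorch.exir.backend.test.backend_with_compiler_demo import (  # noqa
    BackendWithCompilerDemo,
)

# Lower the whole exported program to the demo backend; the empty list
# holds the backend-specific compile specs.
lowered_module = to_backend("BackendWithCompilerDemo", to_be_lowered_module, [])
print(lowered_module)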
@@ -460,8 +423,7 @@ def forward(self, x):
 
 
 example_args = (torch.ones(1),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(ComposedModule(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
@@ -499,8 +461,7 @@ def forward(self, a, x, b):
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(Foo(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
@@ -534,8 +495,7 @@ def forward(self, a, x, b):
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(Foo(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 delegated_program = edge_program.to_backend(AddMulPartitionerDemo())
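From here the tutorial's flow bottoms out in an ExecuTorch program. A sketch of the final lowering and serialization step, continuing from the ``edge_program`` in the hunk above (the output filename is arbitrary):

from executorch.exir import ExecutorchProgramManager

# Lower the (possibly delegated) Edge program to the ExecuTorch backend
# dialect and emit the final program.
executorch_program: ExecutorchProgramManager = edge_program.to_executorch()

# Serialize the flatbuffer for the ExecuTorch runtime to load.
with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)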
