Commit 23e4888

angelayi authored and facebook-github-bot committed
Update tutorial
Summary: Removed the use of capture_pre_autograd_graph in places where we are not quantizing, since we want to minimize the usage of this API for easier deprecation in the future.

Differential Revision: D56475332
1 parent cb77763 commit 23e4888

1 file changed

docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 36 additions & 74 deletions
@@ -44,15 +44,11 @@
 #
 # The first step of lowering to ExecuTorch is to export the given model (any
 # callable or ``torch.nn.Module``) to a graph representation. This is done via
-# the two-stage APIs, ``torch._export.capture_pre_autograd_graph``, and
-# ``torch.export``.
-#
-# Both APIs take in a model (any callable or ``torch.nn.Module``), a tuple of
+# ``torch.export``, which takes in a ``torch.nn.Module``, a tuple of
 # positional arguments, optionally a dictionary of keyword arguments (not shown
 # in the example), and a list of dynamic shapes (covered later).
 
 import torch
-from torch._export import capture_pre_autograd_graph
 from torch.export import export, ExportedProgram
 

@@ -70,40 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-print("Pre-Autograd ATen Dialect Graph")
-print(pre_autograd_aten_dialect)
-
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
-print("ATen Dialect Graph")
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 print(aten_dialect)
 
 ######################################################################
-# The output of ``torch._export.capture_pre_autograd_graph`` is a fully
-# flattened graph (meaning the graph does not contain any module hierarchy,
-# except in the case of control flow operators). Furthermore, the captured graph
-# contains only ATen operators (~3000 ops) which are Autograd safe, for example, safe
-# for eager mode training.
-#
-# The output of ``torch.export`` further compiles the graph to a lower and
-# cleaner representation. Specifically, it has the following:
-#
-# - The graph is purely functional, meaning it does not contain operations with
-#   side effects such as mutations or aliasing.
-# - The graph contains only a small defined
-#   `Core ATen IR <https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir>`__
-#   operator set (~180 ops), along with registered custom operators.
-# - The nodes in the graph contain metadata captured during tracing, such as a
-#   stacktrace from user's code.
+# The output of ``torch.export.export`` is a fully flattened graph (meaning the
+# graph does not contain any module hierarchy, except in the case of control
+# flow operators). Additionally, the graph is purely functional, meaning it does
+# not contain operations with side effects such as mutations or aliasing.
 #
 # More specifications about the result of ``torch.export`` can be found
-# `here <https://pytorch.org/docs/2.1/export.html>`__ .
+# `here <https://pytorch.org/docs/main/export.html>`__ .
 #
-# Since the result of ``torch.export`` is a graph containing the Core ATen
-# operators, we will call this the ``ATen Dialect``, and since
-# ``torch._export.capture_pre_autograd_graph`` returns a graph containing the
-# set of ATen operators which are Autograd safe, we will call it the
-# ``Pre-Autograd ATen Dialect``.
+# The graph returned by ``torch.export`` only contains functional ATen operators
+# (~2000 ops), which we will call the ``ATen Dialect``.
 
 ######################################################################
 # Expressing Dynamism
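
To see the properties described above concretely, one could walk the exported graph's nodes; this is a hypothetical snippet, not part of this commit, reusing ``aten_dialect`` from the hunk above and the standard torch.fx node API.

    for node in aten_dialect.graph.nodes:
        if node.op == "call_function":
            # Every callable target should be a functional ATen operator;
            # no mutating or aliasing ops appear in the graph.
            print(node.target)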
@@ -124,10 +100,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y
 
 
-f = Basic()
 example_args = (torch.randn(3, 3), torch.randn(3, 3))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
-aten_dialect: ExportedProgram = export(f, example_args)
+aten_dialect: ExportedProgram = export(Basic(), example_args)
 
 # Works correctly
 print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
@@ -153,15 +127,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y
 
 
-f = Basic()
 example_args = (torch.randn(3, 3), torch.randn(3, 3))
 dim1_x = Dim("dim1_x", min=1, max=10)
 dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
-pre_autograd_aten_dialect = capture_pre_autograd_graph(
-    f, example_args, dynamic_shapes=dynamic_shapes
+aten_dialect: ExportedProgram = export(
+    Basic(), example_args, dynamic_shapes=dynamic_shapes
 )
-aten_dialect: ExportedProgram = export(f, example_args, dynamic_shapes=dynamic_shapes)
-print("ATen Dialect Graph")
 print(aten_dialect)
 
 ######################################################################
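
As a usage sketch (hypothetical, not part of this diff): since ``dim1_x`` is declared with ``min=1, max=10``, the exported module accepts any second dimension in that range and rejects shapes outside it.

    module = aten_dialect.module()
    print(module(torch.ones(3, 5), torch.ones(3, 5)))  # dim 1 == 5: within range

    try:
        module(torch.ones(3, 11), torch.ones(3, 11))   # dim 1 == 11: above max=10
    except RuntimeError as e:
        # The runtime shape guard should reject the out-of-range input.
        print("shape guard rejected the input:", e)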
@@ -198,7 +169,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # As our goal is to capture the entire computational graph from a PyTorch
 # program, we might ultimately run into untraceable parts of programs. To
 # address these issues, the
-# `torch.export documentation <https://pytorch.org/docs/2.1/export.html#limitations-of-torch-export>`__,
+# `torch.export documentation <https://pytorch.org/docs/main/export.html#limitations-of-torch-export>`__,
 # or the
 # `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
 # would be the best place to look.
@@ -207,10 +178,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # Performing Quantization
 # -----------------------
 #
-# To quantize a model, we can do so between the call to
-# ``torch._export.capture_pre_autograd_graph`` and ``torch.export``, in the
-# ``Pre-Autograd ATen Dialect``. This is because quantization must operate at a
-# level which is safe for eager mode training.
+# To quantize a model, we first need to capture the graph with
+# ``torch._export.capture_pre_autograd_graph``, perform quantization, and then
+# call ``torch.export``. ``torch._export.capture_pre_autograd_graph`` returns a
+# graph containing ATen operators which are Autograd safe, meaning they are
+# safe for eager-mode training, which is needed for quantization. We will call
+# the graph at this level the ``Pre-Autograd ATen Dialect`` graph.
 #
 # Compared to
 # `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
@@ -220,8 +193,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 # will annotate the nodes in the graph with information needed to quantize the
 # model properly for a specific backend.
 
+from torch._export import capture_pre_autograd_graph
+
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
+pre_autograd_aten_dialect = capture_pre_autograd_graph(
+    SimpleConv(), example_args
+)
 print("Pre-Autograd ATen Dialect Graph")
 print(pre_autograd_aten_dialect)
 
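For orientation, here is a minimal sketch of the quantization step that sits between ``capture_pre_autograd_graph`` and ``export``; it is not part of this diff, and it assumes the PT2E APIs (``prepare_pt2e``, ``convert_pt2e``) and the XNNPACK quantizer available in this PyTorch release.

    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(pre_autograd_aten_dialect, quantizer)  # insert observers
    prepared(*example_args)                                        # calibrate on sample data
    converted = convert_pt2e(prepared)                             # produce the quantized graph
    aten_dialect: ExportedProgram = export(converted, example_args)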
@@ -268,13 +245,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 from executorch.exir import EdgeProgramManager, to_edge
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-print("Pre-Autograd ATen Dialect Graph")
-print(pre_autograd_aten_dialect)
-
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
-print("ATen Dialect Graph")
-print(aten_dialect)
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 print("Edge Dialect Graph")
@@ -298,16 +269,10 @@ def forward(self, x):
 
 
 encode_args = (torch.randn(1, 10),)
-aten_encode: ExportedProgram = export(
-    capture_pre_autograd_graph(Encode(), encode_args),
-    encode_args,
-)
+aten_encode: ExportedProgram = export(Encode(), encode_args)
 
 decode_args = (torch.randn(1, 5),)
-aten_decode: ExportedProgram = export(
-    capture_pre_autograd_graph(Decode(), decode_args),
-    decode_args,
-)
+aten_decode: ExportedProgram = export(Decode(), decode_args)
 
 edge_program: EdgeProgramManager = to_edge(
     {"encode": aten_encode, "decode": aten_decode}
@@ -328,8 +293,7 @@ def forward(self, x):
 # rather than the ``torch.ops.aten`` namespace.
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 print("Edge Dialect Graph")
 print(edge_program.exported_program())
@@ -353,7 +317,9 @@ def call_operator(self, op, args, kwargs, meta):
 print(transformed_edge_program.exported_program())
 
 ######################################################################
-# Note: if you see error like `torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical`,
+# Note: if you see an error like ``torch._export.verifier.SpecViolationError:
+# Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
+# Aten Canonical``,
 # please file an issue in https://github.com/pytorch/executorch/issues and we're happy to help!
 
 
@@ -365,7 +331,7 @@ def call_operator(self, op, args, kwargs, meta):
 # backend through the ``to_backend`` API. An in-depth documentation on the
 # specifics of backend delegation, including how to delegate to a backend and
 # how to implement a backend, can be found
-# `here <../compiler-delegate-and-partitioner.html>`__
+# `here <../compiler-delegate-and-partitioner.html>`__.
 #
 # There are three ways for using this API:
 #
@@ -393,8 +359,7 @@ def forward(self, x):
 
 # Export and lower the module to Edge Dialect
 example_args = (torch.ones(1),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(LowerableModule(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 to_be_lowered_module = edge_program.exported_program()
 
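For context, a minimal sketch of the lowering step the surrounding tutorial performs next, assuming ExecuTorch's ``to_backend`` API and its demo backend identifier (not part of this diff):

    from executorch.exir.backend.backend_api import to_backend

    # Lower the entire exported program to the demo backend; the final
    # argument is the list of compile specs for the backend (empty here).
    lowered_module = to_backend("BackendWithCompilerDemo", to_be_lowered_module, [])
    print(lowered_module)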
@@ -460,8 +425,7 @@ def forward(self, x):
 
 
 example_args = (torch.ones(1),)
-pre_autograd_aten_dialect = capture_pre_autograd_graph(ComposedModule(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
@@ -499,8 +463,7 @@ def forward(self, a, x, b):
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(Foo(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 print("Edge Dialect graph")
@@ -534,8 +497,7 @@ def forward(self, a, x, b):
 
 
 example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args)
-aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
+aten_dialect: ExportedProgram = export(Foo(), example_args)
 edge_program: EdgeProgramManager = to_edge(aten_dialect)
 exported_program = edge_program.exported_program()
 delegated_program = edge_program.to_backend(AddMulPartitionerDemo())
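
As a closing sketch (hypothetical; assuming the ``EdgeProgramManager.to_executorch`` API of this release, with ``model.pte`` as an illustrative filename), the delegated program would typically be compiled to the ExecuTorch dialect and serialized:

    executorch_program = delegated_program.to_executorch()
    with open("model.pte", "wb") as f:
        # The flatbuffer-serialized program that the ExecuTorch runtime loads.
        f.write(executorch_program.buffer)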
