Update tutorial #3242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
docs/source/tutorials_source/export-to-executorch-tutorial.py: 33 additions & 73 deletions
@@ -44,15 +44,11 @@
#
# The first step of lowering to ExecuTorch is to export the given model (any
# callable or ``torch.nn.Module``) to a graph representation. This is done via
# ``torch.export``, which takes in a ``torch.nn.Module``, a tuple of
# positional arguments, optionally a dictionary of keyword arguments (not shown
# in the example), and a list of dynamic shapes (covered later).

import torch
from torch.export import export, ExportedProgram


@@ -70,40 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
print("ATen Dialect Graph")
print(aten_dialect)
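
######################################################################
# ``torch.export`` also accepts a dictionary of keyword arguments as its third
# argument. A minimal sketch (the ``AddWithKwarg`` module below is a
# hypothetical example, not part of the original tutorial):


class AddWithKwarg(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


kwarg_dialect: ExportedProgram = export(
    AddWithKwarg(), (torch.randn(2, 2),), {"y": torch.randn(2, 2)}
)
print(kwarg_dialect)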

######################################################################
# The output of ``torch.export.export`` is a fully flattened graph (meaning the
# graph does not contain any module hierarchy, except in the case of control
# flow operators). Additionally, the graph is purely functional, meaning it does
# not contain operations with side effects such as mutations or aliasing.
#
# More specifications about the result of ``torch.export`` can be found
# `here <https://pytorch.org/docs/main/export.html>`__ .
#
# The graph returned by ``torch.export`` only contains functional ATen operators
# (~2000 ops), which we will call the ``ATen Dialect``.
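
######################################################################
# As a quick sanity check (a sketch, not part of the original tutorial), we can
# walk the nodes of the exported graph and print the ATen operators it calls:

for node in aten_dialect.graph.nodes:
    if node.op == "call_function":
        print(node.target)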

######################################################################
# Expressing Dynamism
@@ -124,10 +100,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
aten_dialect: ExportedProgram = export(Basic(), example_args)

# Works correctly
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
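
# Since no dimensions were marked dynamic, inputs whose shapes differ from the
# example inputs are rejected at run time. A sketch of the failure mode (not
# part of the original diff):
try:
    aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2))
except Exception as e:
    print(e)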
@@ -153,15 +127,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
dim1_x = Dim("dim1_x", min=1, max=10)
dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
aten_dialect: ExportedProgram = export(
    Basic(), example_args, dynamic_shapes=dynamic_shapes
)
print("ATen Dialect Graph")
print(aten_dialect)
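
# Because dim 1 of both inputs was marked dynamic (between 1 and 10), the
# exported module now accepts a different size along that dimension. A quick
# check (a sketch, not part of the original diff):
print(aten_dialect.module()(torch.ones(3, 5), torch.ones(3, 5)))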

######################################################################
@@ -198,7 +169,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# As our goal is to capture the entire computational graph from a PyTorch
# program, we might ultimately run into untraceable parts of programs. To
# address these issues, the
# `torch.export documentation <https://pytorch.org/docs/main/export.html#limitations-of-torch-export>`__,
# or the
# `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
# would be the best place to look.
@@ -207,10 +178,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Performing Quantization
# -----------------------
#
# To quantize a model, we first need to capture the graph with
# ``torch._export.capture_pre_autograd_graph``, perform quantization, and then
# call ``torch.export``. ``torch._export.capture_pre_autograd_graph`` returns a
# graph containing ATen operators which are Autograd safe, meaning they are
# safe for eager-mode training, which is needed for quantization. We will call
# the graph at this level the ``Pre-Autograd ATen Dialect`` graph.
#
# Compared to
# `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
@@ -220,6 +193,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# will annotate the nodes in the graph with information needed to quantize the
# model properly for a specific backend.

from torch._export import capture_pre_autograd_graph

example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
print("Pre-Autograd ATen Dialect Graph")
@@ -268,13 +243,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
from executorch.exir import EdgeProgramManager, to_edge

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)

edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
@@ -298,16 +267,10 @@ def forward(self, x):


encode_args = (torch.randn(1, 10),)
aten_encode: ExportedProgram = export(Encode(), encode_args)

decode_args = (torch.randn(1, 5),)
aten_decode: ExportedProgram = export(Decode(), decode_args)

edge_program: EdgeProgramManager = to_edge(
{"encode": aten_encode, "decode": aten_decode}
@@ -328,8 +291,7 @@ def forward(self, x):
# rather than the ``torch.ops.aten`` namespace.

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())
@@ -353,7 +315,9 @@ def call_operator(self, op, args, kwargs, meta):
print(transformed_edge_program.exported_program())
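
######################################################################
# The pass applied above is defined in a portion of the file not shown in this
# excerpt. As a rough sketch (a hypothetical pass, not necessarily the one used
# in the tutorial), an ``ExportPass`` that swaps one Edge operator for another
# looks like this:

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertReluToSigmoid(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        # Rewrite relu nodes to sigmoid; forward everything else unchanged.
        if op == exir_ops.edge.aten.relu.default:
            return super().call_operator(
                exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
            )
        return super().call_operator(op, args, kwargs, meta)


# ``EdgeProgramManager.transform`` takes a sequence of passes and returns a new
# ``EdgeProgramManager`` whose graph has been rewritten.
sigmoid_program = edge_program.transform((ConvertReluToSigmoid(),))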

######################################################################
# Note: if you see an error like ``torch._export.verifier.SpecViolationError:
# Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
# Aten Canonical``, please file an issue at
# https://github.com/pytorch/executorch/issues and we're happy to help!


@@ -365,7 +329,7 @@ def call_operator(self, op, args, kwargs, meta):
# backend through the ``to_backend`` API. In-depth documentation on the
# specifics of backend delegation, including how to delegate to a backend and
# how to implement a backend, can be found
# `here <../compiler-delegate-and-partitioner.html>`__.
#
# There are three ways to use this API:
#
@@ -393,8 +357,7 @@ def forward(self, x):

# Export and lower the module to Edge Dialect
example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
to_be_lowered_module = edge_program.exported_program()
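
# The lowering call itself is not visible in this excerpt. A minimal sketch,
# assuming the ``BackendWithCompilerDemo`` example backend that ships with
# ExecuTorch (arguments: backend id, program to lower, compile specs):
from executorch.exir.backend.backend_api import to_backend

lowered_module = to_backend("BackendWithCompilerDemo", to_be_lowered_module, [])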

@@ -460,8 +423,7 @@ def forward(self, x):


example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
@@ -499,8 +461,7 @@ def forward(self, a, x, b):


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
@@ -534,8 +495,7 @@ def forward(self, a, x, b):


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
delegated_program = edge_program.to_backend(AddMulPartitionerDemo())
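
# A quick check (a sketch, not part of the original diff): after partitioning,
# the add/mul operators should appear as calls into the lowered backend module.
print("Delegated program")
print(delegated_program.exported_program())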