44
44
#
45
45
# The first step of lowering to ExecuTorch is to export the given model (any
46
46
# callable or ``torch.nn.Module``) to a graph representation. This is done via
47
- # the two-stage APIs, ``torch._export.capture_pre_autograd_graph``, and
48
- # ``torch.export``.
49
- #
50
- # Both APIs take in a model (any callable or ``torch.nn.Module``), a tuple of
47
+ # ``torch.export``, which takes in an ``torch.nn.Module``, a tuple of
51
48
# positional arguments, optionally a dictionary of keyword arguments (not shown
52
49
# in the example), and a list of dynamic shapes (covered later).
53
50
54
51
import torch
55
- from torch ._export import capture_pre_autograd_graph
56
52
from torch .export import export , ExportedProgram
57
53
58
54
@@ -70,40 +66,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
66
71
67
72
68
example_args = (torch .randn (1 , 3 , 256 , 256 ),)
73
- pre_autograd_aten_dialect = capture_pre_autograd_graph (SimpleConv (), example_args )
74
- print ("Pre-Autograd ATen Dialect Graph" )
75
- print (pre_autograd_aten_dialect )
76
-
77
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
78
- print ("ATen Dialect Graph" )
69
+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
79
70
print (aten_dialect )
80
71
81
72
######################################################################
82
- # The output of ``torch._export.capture_pre_autograd_graph`` is a fully
83
- # flattened graph (meaning the graph does not contain any module hierarchy,
84
- # except in the case of control flow operators). Furthermore, the captured graph
85
- # contains only ATen operators (~3000 ops) which are Autograd safe, for example, safe
86
- # for eager mode training.
87
- #
88
- # The output of ``torch.export`` further compiles the graph to a lower and
89
- # cleaner representation. Specifically, it has the following:
90
- #
91
- # - The graph is purely functional, meaning it does not contain operations with
92
- # side effects such as mutations or aliasing.
93
- # - The graph contains only a small defined
94
- # `Core ATen IR <https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir>`__
95
- # operator set (~180 ops), along with registered custom operators.
96
- # - The nodes in the graph contain metadata captured during tracing, such as a
97
- # stacktrace from user's code.
73
+ # The output of ``torch.export.export`` is a fully flattened graph (meaning the
74
+ # graph does not contain any module hierarchy, except in the case of control
75
+ # flow operators). Additionally, the graph is purely functional, meaning it does
76
+ # not contain operations with side effects such as mutations or aliasing.
98
77
#
99
78
# More specifications about the result of ``torch.export`` can be found
100
- # `here <https://pytorch.org/docs/2.1 /export.html>`__ .
79
+ # `here <https://pytorch.org/docs/main /export.html>`__ .
101
80
#
102
- # Since the result of ``torch.export`` is a graph containing the Core ATen
103
- # operators, we will call this the ``ATen Dialect``, and since
104
- # ``torch._export.capture_pre_autograd_graph`` returns a graph containing the
105
- # set of ATen operators which are Autograd safe, we will call it the
106
- # ``Pre-Autograd ATen Dialect``.
81
+ # The graph returned by ``torch.export`` only contains functional ATen operators
82
+ # (~2000 ops), which we will call the ``ATen Dialect``.
107
83
108
84
######################################################################
109
85
# Expressing Dynamism
@@ -124,10 +100,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
124
100
return x + y
125
101
126
102
127
- f = Basic ()
128
103
example_args = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
129
- pre_autograd_aten_dialect = capture_pre_autograd_graph (f , example_args )
130
- aten_dialect : ExportedProgram = export (f , example_args )
104
+ aten_dialect : ExportedProgram = export (Basic (), example_args )
131
105
132
106
# Works correctly
133
107
print (aten_dialect .module ()(torch .ones (3 , 3 ), torch .ones (3 , 3 )))
@@ -153,15 +127,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
153
127
return x + y
154
128
155
129
156
- f = Basic ()
157
130
example_args = (torch .randn (3 , 3 ), torch .randn (3 , 3 ))
158
131
dim1_x = Dim ("dim1_x" , min = 1 , max = 10 )
159
132
dynamic_shapes = {"x" : {1 : dim1_x }, "y" : {1 : dim1_x }}
160
- pre_autograd_aten_dialect = capture_pre_autograd_graph (
161
- f , example_args , dynamic_shapes = dynamic_shapes
133
+ aten_dialect : ExportedProgram = export (
134
+ Basic () , example_args , dynamic_shapes = dynamic_shapes
162
135
)
163
- aten_dialect : ExportedProgram = export (f , example_args , dynamic_shapes = dynamic_shapes )
164
- print ("ATen Dialect Graph" )
165
136
print (aten_dialect )
166
137
167
138
######################################################################
@@ -198,7 +169,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
198
169
# As our goal is to capture the entire computational graph from a PyTorch
199
170
# program, we might ultimately run into untraceable parts of programs. To
200
171
# address these issues, the
201
- # `torch.export documentation <https://pytorch.org/docs/2.1 /export.html#limitations-of-torch-export>`__,
172
+ # `torch.export documentation <https://pytorch.org/docs/main /export.html#limitations-of-torch-export>`__,
202
173
# or the
203
174
# `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
204
175
# would be the best place to look.
@@ -207,10 +178,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
207
178
# Performing Quantization
208
179
# -----------------------
209
180
#
210
- # To quantize a model, we can do so between the call to
211
- # ``torch._export.capture_pre_autograd_graph`` and ``torch.export``, in the
212
- # ``Pre-Autograd ATen Dialect``. This is because quantization must operate at a
213
- # level which is safe for eager mode training.
181
+ # To quantize a model, we first need to capture the graph with
182
+ # ``torch._export.capture_pre_autograd_graph``, perform quantization, and then
183
+ # call ``torch.export``. ``torch._export.capture_pre_autograd_graph`` returns a
184
+ # graph which contains ATen operators which are Autograd safe, meaning they are
185
+ # safe for eager-mode training, which is needed for quantization. We will call
186
+ # the graph at this level, the ``Pre-Autograd ATen Dialect`` graph.
214
187
#
215
188
# Compared to
216
189
# `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
@@ -220,6 +193,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
220
193
# will annotate the nodes in the graph with information needed to quantize the
221
194
# model properly for a specific backend.
222
195
196
+ from torch ._export import capture_pre_autograd_graph
197
+
223
198
example_args = (torch .randn (1 , 3 , 256 , 256 ),)
224
199
pre_autograd_aten_dialect = capture_pre_autograd_graph (SimpleConv (), example_args )
225
200
print ("Pre-Autograd ATen Dialect Graph" )
@@ -268,13 +243,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
268
243
from executorch .exir import EdgeProgramManager , to_edge
269
244
270
245
example_args = (torch .randn (1 , 3 , 256 , 256 ),)
271
- pre_autograd_aten_dialect = capture_pre_autograd_graph (SimpleConv (), example_args )
272
- print ("Pre-Autograd ATen Dialect Graph" )
273
- print (pre_autograd_aten_dialect )
274
-
275
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
276
- print ("ATen Dialect Graph" )
277
- print (aten_dialect )
246
+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
278
247
279
248
edge_program : EdgeProgramManager = to_edge (aten_dialect )
280
249
print ("Edge Dialect Graph" )
@@ -298,16 +267,10 @@ def forward(self, x):
298
267
299
268
300
269
encode_args = (torch .randn (1 , 10 ),)
301
- aten_encode : ExportedProgram = export (
302
- capture_pre_autograd_graph (Encode (), encode_args ),
303
- encode_args ,
304
- )
270
+ aten_encode : ExportedProgram = export (Encode (), encode_args )
305
271
306
272
decode_args = (torch .randn (1 , 5 ),)
307
- aten_decode : ExportedProgram = export (
308
- capture_pre_autograd_graph (Decode (), decode_args ),
309
- decode_args ,
310
- )
273
+ aten_decode : ExportedProgram = export (Decode (), decode_args )
311
274
312
275
edge_program : EdgeProgramManager = to_edge (
313
276
{"encode" : aten_encode , "decode" : aten_decode }
@@ -328,8 +291,7 @@ def forward(self, x):
328
291
# rather than the ``torch.ops.aten`` namespace.
329
292
330
293
example_args = (torch .randn (1 , 3 , 256 , 256 ),)
331
- pre_autograd_aten_dialect = capture_pre_autograd_graph (SimpleConv (), example_args )
332
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
294
+ aten_dialect : ExportedProgram = export (SimpleConv (), example_args )
333
295
edge_program : EdgeProgramManager = to_edge (aten_dialect )
334
296
print ("Edge Dialect Graph" )
335
297
print (edge_program .exported_program ())
@@ -353,7 +315,9 @@ def call_operator(self, op, args, kwargs, meta):
353
315
print (transformed_edge_program .exported_program ())
354
316
355
317
######################################################################
356
- # Note: if you see error like `torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical`,
318
+ # Note: if you see error like ``torch._export.verifier.SpecViolationError:
319
+ # Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
320
+ # Aten Canonical``,
357
321
# please file an issue in https://github.com/pytorch/executorch/issues and we're happy to help!
358
322
359
323
@@ -365,7 +329,7 @@ def call_operator(self, op, args, kwargs, meta):
365
329
# backend through the ``to_backend`` API. An in-depth documentation on the
366
330
# specifics of backend delegation, including how to delegate to a backend and
367
331
# how to implement a backend, can be found
368
- # `here <../compiler-delegate-and-partitioner.html>`__
332
+ # `here <../compiler-delegate-and-partitioner.html>`__.
369
333
#
370
334
# There are three ways for using this API:
371
335
#
@@ -393,8 +357,7 @@ def forward(self, x):
393
357
394
358
# Export and lower the module to Edge Dialect
395
359
example_args = (torch .ones (1 ),)
396
- pre_autograd_aten_dialect = capture_pre_autograd_graph (LowerableModule (), example_args )
397
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
360
+ aten_dialect : ExportedProgram = export (LowerableModule (), example_args )
398
361
edge_program : EdgeProgramManager = to_edge (aten_dialect )
399
362
to_be_lowered_module = edge_program .exported_program ()
400
363
@@ -460,8 +423,7 @@ def forward(self, x):
460
423
461
424
462
425
example_args = (torch .ones (1 ),)
463
- pre_autograd_aten_dialect = capture_pre_autograd_graph (ComposedModule (), example_args )
464
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
426
+ aten_dialect : ExportedProgram = export (ComposedModule (), example_args )
465
427
edge_program : EdgeProgramManager = to_edge (aten_dialect )
466
428
exported_program = edge_program .exported_program ()
467
429
print ("Edge Dialect graph" )
@@ -499,8 +461,7 @@ def forward(self, a, x, b):
499
461
500
462
501
463
example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
502
- pre_autograd_aten_dialect = capture_pre_autograd_graph (Foo (), example_args )
503
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
464
+ aten_dialect : ExportedProgram = export (Foo (), example_args )
504
465
edge_program : EdgeProgramManager = to_edge (aten_dialect )
505
466
exported_program = edge_program .exported_program ()
506
467
print ("Edge Dialect graph" )
@@ -534,8 +495,7 @@ def forward(self, a, x, b):
534
495
535
496
536
497
example_args = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
537
- pre_autograd_aten_dialect = capture_pre_autograd_graph (Foo (), example_args )
538
- aten_dialect : ExportedProgram = export (pre_autograd_aten_dialect , example_args )
498
+ aten_dialect : ExportedProgram = export (Foo (), example_args )
539
499
edge_program : EdgeProgramManager = to_edge (aten_dialect )
540
500
exported_program = edge_program .exported_program ()
541
501
delegated_program = edge_program .to_backend (AddMulPartitionerDemo ())
0 commit comments