
Commit b40f8ef ("Update export tutorial")
Parent: 361c4c7

File tree: 2 files changed (+121, -71 lines)

.jenkins/metadata.json (4 additions, 1 deletion)

@@ -29,7 +29,7 @@
         "needs": "linux.16xlarge.nvidia.gpu"
     },
     "intermediate_source/torchvision_tutorial.py": {
-        "needs": "linux.g5.4xlarge.nvidia.gpu",
+        "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
     },
     "advanced_source/coding_ddpg.py": {
@@ -39,6 +39,9 @@
     "intermediate_source/torch_compile_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
    },
+    "intermediate_source/torch_export_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "intermediate_source/scaled_dot_product_attention_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },

intermediate_source/torch_export_tutorial.py (117 additions, 70 deletions)
@@ -11,11 +11,12 @@
 # .. warning::
 #
 #     ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
-#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.2.
+#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
 #
 # :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
 # standardized model representations, intended
-# to be run on different (i.e. Python-less) environments.
+# to be run on different (i.e. Python-less) environments. The official
+# documentation can be found `here <https://pytorch.org/docs/main/export.html>`__.
 #
 # In this tutorial, you will learn how to use :func:`torch.export` to extract
 # ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
@@ -71,7 +72,7 @@ def forward(self, x, y):
 mod = MyModule()
 exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
 print(type(exported_mod))
-print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
+print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))


 ######################################################################
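The recurring change in this hunk, and in most hunks below, is the calling convention: an ``ExportedProgram`` is no longer called directly, and ``.module()`` is used to obtain a runnable ``torch.nn.Module`` first. A minimal sketch of the round trip under PyTorch 2.3; ``ReluModule`` is an illustrative stand-in, not tutorial code:

import torch
from torch.export import export

class ReluModule(torch.nn.Module):  # illustrative module, not from the tutorial
    def forward(self, x):
        return torch.relu(x)

# export() captures a single-graph ExportedProgram from the module
ep = export(ReluModule(), (torch.randn(4, 4),))
# ep.module() wraps the captured graph in a callable nn.Module
out = ep.module()(torch.randn(4, 4))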
@@ -100,7 +101,7 @@ def forward(self, x, y):
 # Other attributes of interest in ``ExportedProgram`` include:
 #
 # - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
-# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later
+# - ``range_constraints`` -- constraints, covered later

 print(exported_mod.graph_signature)
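Besides ``graph_signature`` and ``range_constraints``, the captured graph itself can be inspected; a short sketch, assuming the ``exported_mod`` from the hunk above:

# the underlying torch.fx graph of the exported program
print(exported_mod.graph)
# print_readable() dumps the generated Python source of the graph module
exported_mod.graph_module.print_readable()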

@@ -123,54 +124,58 @@ def forward(self, x, y):
 #
 # - data-dependent control flow

-def bad1(x):
-    if x.sum() > 0:
-        return torch.sin(x)
-    return torch.cos(x)
+class Bad1(torch.nn.Module):
+    def forward(self, x):
+        if x.sum() > 0:
+            return torch.sin(x)
+        return torch.cos(x)

 import traceback as tb
 try:
-    export(bad1, (torch.randn(3, 3),))
+    export(Bad1(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()

 ######################################################################
 # - accessing tensor data with ``.data``

-def bad2(x):
-    x.data[0, 0] = 3
-    return x
+class Bad2(torch.nn.Module):
+    def forward(self, x):
+        x.data[0, 0] = 3
+        return x

 try:
-    export(bad2, (torch.randn(3, 3),))
+    export(Bad2(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()

 ######################################################################
 # - calling unsupported functions (such as many built-in functions)

-def bad3(x):
-    x = x + 1
-    return x + id(x)
+class Bad3(torch.nn.Module):
+    def forward(self, x):
+        x = x + 1
+        return x + id(x)

 try:
-    export(bad3, (torch.randn(3, 3),))
+    export(Bad3(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()

 ######################################################################
 # - unsupported Python language features (e.g. throwing exceptions, match statements)

-def bad4(x):
-    try:
-        x = x + 1
-        raise RuntimeError("bad")
-    except:
-        x = x + 2
-    return x
+class Bad4(torch.nn.Module):
+    def forward(self, x):
+        try:
+            x = x + 1
+            raise RuntimeError("bad")
+        except:
+            x = x + 2
+        return x

 try:
-    export(bad4, (torch.randn(3, 3),))
+    export(Bad4(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()
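Each ``bad*`` example is rewritten from a free function to a ``torch.nn.Module`` subclass because ``torch.export`` now expects an ``nn.Module`` as its first argument, as this commit's pattern shows. A plain function can still be exported through a thin wrapper; a sketch with illustrative names (``fn``, ``FnWrapper``):

import torch
from torch.export import export

def fn(x):  # illustrative free function
    return torch.sin(x) + 1

class FnWrapper(torch.nn.Module):  # illustrative wrapper, not tutorial code
    def forward(self, x):
        # delegate to the free function so it is traced as part of forward
        return fn(x)

ep = export(FnWrapper(), (torch.randn(3, 3),))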

@@ -188,16 +193,17 @@ def bad4(x):

 from functorch.experimental.control_flow import cond

-def bad1_fixed(x):
-    def true_fn(x):
-        return torch.sin(x)
-    def false_fn(x):
-        return torch.cos(x)
-    return cond(x.sum() > 0, true_fn, false_fn, [x])
+class Bad1Fixed(torch.nn.Module):
+    def forward(self, x):
+        def true_fn(x):
+            return torch.sin(x)
+        def false_fn(x):
+            return torch.cos(x)
+        return cond(x.sum() > 0, true_fn, false_fn, [x])

-exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
-print(exported_bad1_fixed(torch.ones(3, 3)))
-print(exported_bad1_fixed(-torch.ones(3, 3)))
+exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
+print(exported_bad1_fixed.module()(torch.ones(3, 3)))
+print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

 ######################################################################
 # There are limitations to ``cond`` that one should be aware of:
@@ -255,7 +261,7 @@ def forward(self, x, y):
 exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

 try:
-    exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
+    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
 except Exception:
     tb.print_exc()
@@ -286,32 +292,33 @@ def forward(self, x, y):

 inp1 = torch.randn(10, 10, 2)

-def dynamic_shapes_example1(x):
-    x = x[:, 2:]
-    return torch.relu(x)
+class DynamicShapesExample1(torch.nn.Module):
+    def forward(self, x):
+        x = x[:, 2:]
+        return torch.relu(x)

 inp1_dim0 = Dim("inp1_dim0")
 inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
 dynamic_shapes1 = {
     "x": {0: inp1_dim0, 1: inp1_dim1},
 }

-exported_dynamic_shapes_example1 = export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1)
+exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)

-print(exported_dynamic_shapes_example1(torch.randn(5, 5, 2)))
+print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))

 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 1, 2))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
 except Exception:
     tb.print_exc()

 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 20, 2))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
 except Exception:
     tb.print_exc()

 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 8, 3))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
 except Exception:
     tb.print_exc()
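The three failing calls above probe each side of the declared constraints: a dim-1 size below ``min=4``, one above ``max=18``, and a change to the dimension left static. A compact sketch of the same declaration style, assuming PyTorch 2.3 (``Slice``, ``batch``, and ``width`` are illustrative names):

import torch
from torch.export import Dim, export

class Slice(torch.nn.Module):  # illustrative module
    def forward(self, x):
        return torch.relu(x[:, 2:])

batch = Dim("batch")                 # unbounded dynamic dimension
width = Dim("width", min=4, max=18)  # bounded dynamic dimension
ep = export(Slice(), (torch.randn(10, 10, 2),),
            dynamic_shapes={"x": {0: batch, 1: width}})
print(ep.module()(torch.randn(5, 7, 2)))  # ok: 4 <= 7 <= 18, dim 2 unchanged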

@@ -325,7 +332,7 @@ def dynamic_shapes_example1(x):
 }

 try:
-    export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1_bad)
+    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
 except Exception:
     tb.print_exc()
@@ -336,8 +343,9 @@
 inp2 = torch.randn(4, 8)
 inp3 = torch.randn(8, 2)

-def dynamic_shapes_example2(x, y):
-    return x @ y
+class DynamicShapesExample2(torch.nn.Module):
+    def forward(self, x, y):
+        return x @ y

 inp2_dim0 = Dim("inp2_dim0")
 inner_dim = Dim("inner_dim")
@@ -348,12 +356,12 @@ def dynamic_shapes_example2(x, y):
     "y": {0: inner_dim, 1: inp3_dim1},
 }

-exported_dynamic_shapes_example2 = export(dynamic_shapes_example2, (inp2, inp3), dynamic_shapes=dynamic_shapes2)
+exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)

-print(exported_dynamic_shapes_example2(torch.randn(2, 16), torch.randn(16, 4)))
+print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))

 try:
-    exported_dynamic_shapes_example2(torch.randn(4, 8), torch.randn(4, 2))
+    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
 except Exception:
     tb.print_exc()
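Reusing one ``Dim`` object across inputs, as ``inner_dim`` is reused above, is how equality between dimensions is expressed now that the separate ``equality_constraints`` field is gone. A tiny sketch with illustrative names (``MatMul``, ``k``):

import torch
from torch.export import Dim, export

class MatMul(torch.nn.Module):  # illustrative module
    def forward(self, x, y):
        return x @ y

k = Dim("k")  # shared Dim: x.shape[1] must always equal y.shape[0]
ep = export(MatMul(), (torch.randn(4, 8), torch.randn(8, 2)),
            dynamic_shapes={"x": {1: k}, "y": {0: k}})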

@@ -367,18 +375,19 @@ def dynamic_shapes_example2(x, y):
 inp4 = torch.randn(8, 16)
 inp5 = torch.randn(16, 32)

-def dynamic_shapes_example3(x, y):
-    if x.shape[0] <= 16:
-        return x @ y[:, :16]
-    return y
+class DynamicShapesExample3(torch.nn.Module):
+    def forward(self, x, y):
+        if x.shape[0] <= 16:
+            return x @ y[:, :16]
+        return y

 dynamic_shapes3 = {
     "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
     "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
 }

 try:
-    export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3)
+    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
 except Exception:
     tb.print_exc()
@@ -400,8 +409,8 @@ def suggested_fixes():
 }

 dynamic_shapes3_fixed = suggested_fixes()
-exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
-print(exported_dynamic_shapes_example3(torch.randn(4, 32), torch.randn(32, 64)))
+exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
+print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))

 ######################################################################
 # Note that in the example above, because we constrained the value of ``x.shape[0]`` in
@@ -414,18 +423,16 @@ def suggested_fixes():

 import logging
 torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
-exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
+exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

 # reset to previous values
 torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

 ######################################################################
-# We can view an ``ExportedProgram``'s constraints using the ``range_constraints`` and
-# ``equality_constraints`` attributes. The logging above reveals what the symbols ``s0, s1, ...``
-# represent.
+# We can view an ``ExportedProgram``'s symbolic shape ranges using the
+# ``range_constraints`` field.

 print(exported_dynamic_shapes_example3.range_constraints)
-print(exported_dynamic_shapes_example3.equality_constraints)

 ######################################################################
 # Custom Ops
@@ -438,7 +445,7 @@ def suggested_fixes():
 # - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
 #   as with any other custom op

-from torch.library import Library, impl
+from torch.library import Library, impl, impl_abstract

 m = Library("my_custom_library", "DEF")
@@ -453,25 +460,26 @@ def custom_op(x):
 # - Define a ``"Meta"`` implementation of the custom op that returns an empty
 #   tensor with the same shape as the expected output

-@impl(m, "custom_op", "Meta")
+@impl_abstract("my_custom_library::custom_op")
 def custom_op_meta(x):
     return torch.empty_like(x)

 ######################################################################
 # - Call the custom op from the code you want to export using ``torch.ops``

-def custom_op_example(x):
-    x = torch.sin(x)
-    x = torch.ops.my_custom_library.custom_op(x)
-    x = torch.cos(x)
-    return x
+class CustomOpExample(torch.nn.Module):
+    def forward(self, x):
+        x = torch.sin(x)
+        x = torch.ops.my_custom_library.custom_op(x)
+        x = torch.cos(x)
+        return x

 ######################################################################
 # - Export the code as before

-exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
+exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
 exported_custom_op_example.graph_module.print_readable()
-print(exported_custom_op_example(torch.randn(3, 3)))
+print(exported_custom_op_example.module()(torch.randn(3, 3)))

 ######################################################################
 # Note in the above outputs that the custom op is included in the exported graph.
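Putting the custom-op hunks together: the schema and kernel registration are unchanged, and only the meta implementation moves from ``@impl(m, "custom_op", "Meta")`` to ``@impl_abstract("my_custom_library::custom_op")``. A consolidated sketch of the full flow as of PyTorch 2.3; the kernel body is assumed, since the hunks' context does not show it:

import torch
from torch.library import Library, impl, impl_abstract

m = Library("my_custom_library", "DEF")
m.define("custom_op(Tensor input) -> Tensor")

@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op(x):
    # assumed kernel body, for illustration only
    return torch.relu(x)

@impl_abstract("my_custom_library::custom_op")
def custom_op_meta(x):
    # shape-only implementation used while tracing for export
    return torch.empty_like(x)

x = torch.randn(3, 3)
assert torch.allclose(torch.ops.my_custom_library.custom_op(x), torch.relu(x))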
@@ -606,6 +614,45 @@ def cond_predicate(x):
 # ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
 # out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

+######################################################################
+# Running the Exported Program
+# ----------------------------
+#
+# As ``torch.export`` is only a graph capturing mechanism, calling the artifact
+# produced by ``torch.export`` eagerly will be equivalent to running the eager
+# module. To optimize the execution of the Exported Program, we can pass this
+# exported artifact to backends such as Inductor through ``torch.compile``,
+# `AOTInductor <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`__,
+# or `TensorRT <https://pytorch.org/TensorRT/dynamo/dynamo_export.html>`__.
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+ep = torch.export.export(M().to(device="cuda"), (torch.ones(2, 3, device="cuda"),))
+inp = torch.randn(2, 3, device="cuda")
+
+# Run it eagerly
+res = ep.module()(inp)
+print(res)
+
+# Run it with torch.compile
+res = torch.compile(ep.module(), backend="inductor")(inp)
+print(res)
+
+# Compile the exported program to a .so using AOTInductor
+so_path = torch._export.aot_compile(ep.module(), (inp,))
+# Load and run the .so in Python.
+# To load and run it in a C++ environment, please take a look at
+# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
+res = torch._export.aot_load(so_path, device="cuda")(inp)
+print(res)
+
 ######################################################################
 # Conclusion
 # ----------
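Beyond running the program in-process as the new section does, an ``ExportedProgram`` can also be serialized to disk and reloaded later; a brief sketch, assuming the ``torch.export.save``/``torch.export.load`` API of the same PyTorch 2.3 timeframe:

import torch

# ep and inp as defined in the "Running the Exported Program" section above
torch.export.save(ep, "exported_program.pt2")
ep_loaded = torch.export.load("exported_program.pt2")
print(ep_loaded.module()(inp))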
