Update export tutorial #2806

Merged
merged 6 commits on Mar 21, 2024
Changes from 2 commits
5 changes: 4 additions & 1 deletion .jenkins/metadata.json
@@ -29,7 +29,7 @@
"needs": "linux.16xlarge.nvidia.gpu"
},
"intermediate_source/torchvision_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu",
"needs": "linux.g5.4xlarge.nvidia.gpu",
"_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
},
"advanced_source/coding_ddpg.py": {
@@ -39,6 +39,9 @@
"intermediate_source/torch_compile_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"intermediate_source/torch_export_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"intermediate_source/scaled_dot_product_attention_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
187 changes: 117 additions & 70 deletions intermediate_source/torch_export_tutorial.py
@@ -11,11 +11,12 @@
# .. warning::
#
# ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.2.
# breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
#
# :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
# standardized model representations, intended
# to be run on different (i.e. Python-less) environments.
# to be run on different (i.e. Python-less) environments. The official
# documentation can be found `here <https://pytorch.org/docs/main/export.html>`__.
#
# In this tutorial, you will learn how to use :func:`torch.export` to extract
# ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
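
######################################################################
# The definition of ``MyModule`` is collapsed in this diff view. A minimal
# sketch that is consistent with the calls below -- assumed for
# illustration, not taken from this PR -- might look like:

import torch
from torch.export import export

class MyModule(torch.nn.Module):  # assumed definition, for illustration only
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        # any forward accepting two (8, 100) tensors would fit the usage below
        return torch.nn.functional.relu(self.lin(x + y))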
@@ -71,7 +72,7 @@ def forward(self, x, y):
mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))


######################################################################
@@ -100,7 +101,7 @@ def forward(self, x, y):
# Other attributes of interest in ``ExportedProgram`` include:
#
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later
# - ``range_constraints`` -- constraints, covered later

print(exported_mod.graph_signature)

@@ -123,54 +124,58 @@ def forward(self, x, y):
#
# - data-dependent control flow

def bad1(x):
if x.sum() > 0:
return torch.sin(x)
return torch.cos(x)
class Bad1(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return torch.sin(x)
return torch.cos(x)

import traceback as tb
try:
export(bad1, (torch.randn(3, 3),))
export(Bad1(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()

######################################################################
# - accessing tensor data with ``.data``

def bad2(x):
x.data[0, 0] = 3
return x
class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x

try:
export(bad2, (torch.randn(3, 3),))
export(Bad2(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()

######################################################################
# - calling unsupported functions (such as many built-in functions)

def bad3(x):
x = x + 1
return x + id(x)
class Bad3(torch.nn.Module):
def forward(self, x):
x = x + 1
return x + id(x)

try:
export(bad3, (torch.randn(3, 3),))
export(Bad3(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()

######################################################################
# - unsupported Python language features (e.g. throwing exceptions, match statements)

def bad4(x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x
class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x

try:
export(bad4, (torch.randn(3, 3),))
export(Bad4(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()

@@ -188,16 +193,17 @@ def bad4(x):

from functorch.experimental.control_flow import cond

def bad1_fixed(x):
def true_fn(x):
return torch.sin(x)
def false_fn(x):
return torch.cos(x)
return cond(x.sum() > 0, true_fn, false_fn, [x])
class Bad1Fixed(torch.nn.Module):
def forward(self, x):
def true_fn(x):
return torch.sin(x)
def false_fn(x):
return torch.cos(x)
return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))
exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

######################################################################
# There are limitations to ``cond`` that one should be aware of:
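
######################################################################
# (The full list of limitations is collapsed in this view.) One constraint
# worth noting: both branches of ``cond`` must return tensors with matching
# shape and dtype. A minimal sketch of a violation, assuming this behavior:

class CondMismatch(torch.nn.Module):  # illustrative only, not from this PR
    def forward(self, x):
        def true_fn(x):
            return x[:2]   # shape (2, 3)
        def false_fn(x):
            return x       # shape (3, 3) -- mismatched with the true branch
        return cond(x.sum() > 0, true_fn, false_fn, [x])

try:
    export(CondMismatch(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()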
@@ -255,7 +261,7 @@ def forward(self, x, y):
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
tb.print_exc()

@@ -286,32 +292,33 @@ def forward(self, x, y):

inp1 = torch.randn(10, 10, 2)

def dynamic_shapes_example1(x):
x = x[:, 2:]
return torch.relu(x)
class DynamicShapesExample1(torch.nn.Module):
def forward(self, x):
x = x[:, 2:]
return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
"x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1)
exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1(torch.randn(5, 5, 2)))
print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))

try:
exported_dynamic_shapes_example1(torch.randn(8, 1, 2))
exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
tb.print_exc()

try:
exported_dynamic_shapes_example1(torch.randn(8, 20, 2))
exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
tb.print_exc()

try:
exported_dynamic_shapes_example1(torch.randn(8, 8, 3))
exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
tb.print_exc()

@@ -325,7 +332,7 @@ def dynamic_shapes_example1(x):
}

try:
export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1_bad)
export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
tb.print_exc()

@@ -336,8 +343,9 @@ def dynamic_shapes_example1(x):
inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

def dynamic_shapes_example2(x, y):
return x @ y
class DynamicShapesExample2(torch.nn.Module):
def forward(self, x, y):
return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
@@ -348,12 +356,12 @@ def dynamic_shapes_example2(x, y):
"y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(dynamic_shapes_example2, (inp2, inp3), dynamic_shapes=dynamic_shapes2)
exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2(torch.randn(2, 16), torch.randn(16, 4)))
print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))

try:
exported_dynamic_shapes_example2(torch.randn(4, 8), torch.randn(4, 2))
exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
tb.print_exc()

@@ -367,18 +375,19 @@ def dynamic_shapes_example2(x, y):
inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

def dynamic_shapes_example3(x, y):
if x.shape[0] <= 16:
return x @ y[:, :16]
return y
class DynamicShapesExample3(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] <= 16:
return x @ y[:, :16]
return y

dynamic_shapes3 = {
"x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
"y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3)
export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
tb.print_exc()
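
######################################################################
# The body of ``suggested_fixes`` is collapsed below; it applies the bounds
# suggested by the export error. A hypothetical reconstruction (names and
# bounds assumed, not taken from this diff): ``x.shape[0]`` needs an upper
# bound of 16 to stay on the traced branch, the matmul requires
# ``x.shape[1] == y.shape[0]``, and ``y.shape[1]`` must cover the ``[:16]``
# slice.

def suggested_fixes_sketch():  # hypothetical helper, for illustration only
    shared_dim = Dim("shared_dim")          # ties x.shape[1] to y.shape[0]
    inp4_dim0 = Dim("inp4_dim0", max=16)    # branch guard: x.shape[0] <= 16
    inp5_dim1 = Dim("inp5_dim1", min=16)    # y[:, :16] needs >= 16 columns
    return {
        "x": {0: inp4_dim0, 1: shared_dim},
        "y": {0: shared_dim, 1: inp5_dim1},
    }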

@@ -400,8 +409,8 @@ def suggested_fixes():
}

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3(torch.randn(4, 32), torch.randn(32, 64)))
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))

######################################################################
# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
@@ -414,18 +423,16 @@

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

######################################################################
# We can view an ``ExportedProgram``'s constraints using the ``range_constraints`` and
# ``equality_constraints`` attributes. The logging above reveals what the symbols ``s0, s1, ...``
# represent.
# We can view an ``ExportedProgram``'s symbolic shape ranges using the
# ``range_constraints`` field.

print(exported_dynamic_shapes_example3.range_constraints)
print(exported_dynamic_shapes_example3.equality_constraints)

######################################################################
# Custom Ops
Expand All @@ -438,7 +445,7 @@ def suggested_fixes():
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
# as with any other custom op

from torch.library import Library, impl
from torch.library import Library, impl, impl_abstract

m = Library("my_custom_library", "DEF")
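
######################################################################
# The op schema and its default implementation are collapsed in this view.
# A sketch of what such a definition typically looks like with
# ``torch.library`` -- the schema and body here are assumed, not taken
# from this diff:

m.define("custom_op(Tensor input) -> Tensor")

@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op_impl(x):  # hypothetical implementation
    return x.relu()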

@@ -453,25 +460,26 @@ def custom_op(x):
# - Define a ``"Meta"`` implementation of the custom op that returns an empty
# tensor with the same shape as the expected output

@impl(m, "custom_op", "Meta")
@impl_abstract("my_custom_library::custom_op")
def custom_op_meta(x):
return torch.empty_like(x)

######################################################################
# - Call the custom op from the code you want to export using ``torch.ops``

def custom_op_example(x):
x = torch.sin(x)
x = torch.ops.my_custom_library.custom_op(x)
x = torch.cos(x)
return x
class CustomOpExample(torch.nn.Module):
def forward(self, x):
x = torch.sin(x)
x = torch.ops.my_custom_library.custom_op(x)
x = torch.cos(x)
return x

######################################################################
# - Export the code as before

exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example(torch.randn(3, 3)))
print(exported_custom_op_example.module()(torch.randn(3, 3)))

######################################################################
# Note in the above outputs that the custom op is included in the exported graph.
@@ -606,6 +614,45 @@ def cond_predicate(x):
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

######################################################################
# Running the Exported Program
# ----------------------------
#
# As ``torch.export`` is only a graph capturing mechanism, calling the artifact
# produced by ``torch.export`` eagerly will be equivalent to running the eager
# module. To optimize the execution of the Exported Program, we can pass this
# exported artifact to backends such as Inductor through ``torch.compile``,
# `AOTInductor <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`__,
# or `TensorRT <https://pytorch.org/TensorRT/dynamo/dynamo_export.html>`__.

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.linear(x)
return x

ep = torch.export.export(M().to(device="cuda"), (torch.ones(2, 3, device="cuda"),))
inp = torch.randn(2, 3, device="cuda")

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)

# Compile the exported program to a .so using AOTInductor
so_path = torch._export.aot_compile(ep.module(), (inp,))
# Load and run the .so in python.
# To load and run it in a C++ environment, please take a look at
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)
print(res)

######################################################################
# Conclusion
# ----------