
Commit 15a6e57

feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.compile` path, which can compile models with data-dependent control flow and other constructs that can make other compilation methods more difficult
- Cover the different customizable features allowed in the new backend
- Make scripts Sphinx-Gallery compatible Python files

fix: Update `index.rst`

- Show individual links in sidebar

chore: Add note about Cuda Driver Error

- Update arguments to Dynamo compile call in line with new schema updates

fix: Update function calls to address API changes

fix: Update file and reference naming for new API
1 parent ce06f6e commit 15a6e57
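As background for the first bullet of the commit message: because this path goes through `torch.compile`, Dynamo can insert graph breaks at data-dependent branches and hand the TensorRT-eligible subgraphs to the backend. A minimal sketch of such a model, assuming only the `torch.compile(..., backend="torch_tensorrt")` entry point already used by the scripts in this commit (`ConditionalModel` is illustrative, not part of the commit):

import torch
import torch_tensorrt  # importing registers the "torch_tensorrt" backend

class ConditionalModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: Dynamo breaks the graph here and the
        # backend compiles the subgraphs on either side of the break
        if x.sum() > 0:
            return torch.relu(x)
        return torch.tanh(x)

model = ConditionalModel().eval().cuda()
optimized = torch.compile(model, backend="torch_tensorrt")
optimized(torch.randn(4, 4).cuda())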

File tree

8 files changed: +346 −10 lines changed


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ docsrc/_build
 docsrc/_notebooks
 docsrc/_cpp_api
 docsrc/_tmp
+docsrc/tutorials/_rendered_examples
 *.so
 __pycache__
 *.egg-info
@@ -67,4 +68,4 @@ bazel-tensorrt
 *cifar-10-batches-py*
 bazel-project
 build/
-wheelhouse/
+wheelhouse/

docsrc/conf.py

Lines changed: 7 additions & 0 deletions
@@ -47,6 +47,7 @@
     "sphinx.ext.coverage",
     "sphinx.ext.mathjax",
     "sphinx.ext.viewcode",
+    "sphinx_gallery.gen_gallery",
 ]

 napoleon_use_ivar = True
@@ -79,6 +80,12 @@
 # so a file named "default.css" will overwrite the builtin "default.css".
 html_static_path = ["_static"]

+# sphinx-gallery configuration
+sphinx_gallery_conf = {
+    "examples_dirs": "../examples/dynamo",
+    "gallery_dirs": "tutorials/_rendered_examples/",
+}
+
 # Setup the breathe extension
 breathe_projects = {"Torch-TensorRT": "./_tmp/xml"}
 breathe_default_project = "Torch-TensorRT"
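With this configuration, Sphinx-Gallery picks up the Python scripts under `examples/dynamo` and renders each into a page under `docsrc/tutorials/_rendered_examples/`. As a reminder of what "Sphinx-Gallery compatible" means for the scripts in this commit, a minimal sketch of the expected script layout (file name and contents hypothetical):

"""
.. _my_example:

My Example Title
================

Text in the module docstring becomes the rendered page's introduction.
"""

# %%
# Section Heading
# ^^^^^^^^^^^^^^^
# Comment blocks following a ``# %%`` cell marker are rendered as prose;
# code between markers is rendered as code (and executed, with output
# captured, when the file matches Sphinx-Gallery's filename_pattern).

import torch

print(torch.rand(2, 2))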

docsrc/index.rst

Lines changed: 22 additions & 9 deletions
@@ -36,30 +36,43 @@ Getting Started
    getting_started/getting_started_with_windows


-Tutorials
+User Guide
 ------------
 * :ref:`creating_a_ts_mod`
 * :ref:`getting_started_with_fx`
 * :ref:`ptq`
 * :ref:`runtime`
-* :ref:`serving_torch_tensorrt_with_triton`
 * :ref:`use_from_pytorch`
 * :ref:`using_dla`
+
+.. toctree::
+   :caption: User Guide
+   :maxdepth: 1
+   :hidden:
+
+   user_guide/creating_torchscript_module_in_python
+   user_guide/getting_started_with_fx_path
+   user_guide/ptq
+   user_guide/runtime
+   user_guide/use_from_pytorch
+   user_guide/using_dla
+
+Tutorials
+------------
+* :ref:`serving_torch_tensorrt_with_triton`
 * :ref:`notebooks`
+* :ref:`dynamo_compile`

 .. toctree::
    :caption: Tutorials
-   :maxdepth: 1
+   :maxdepth: 3
    :hidden:

-   tutorials/creating_torchscript_module_in_python
-   tutorials/getting_started_with_fx_path
-   tutorials/ptq
-   tutorials/runtime
    tutorials/serving_torch_tensorrt_with_triton
-   tutorials/use_from_pytorch
-   tutorials/using_dla
    tutorials/notebooks
+   tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
+   tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
+   tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage

 Python API Documentation
 ------------------------

docsrc/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 sphinx==4.5.0
+sphinx-gallery==0.13.0
 breathe==4.33.1
 exhale==0.3.1
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme

examples/dynamo/README.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+.. _torch_compile:
+
+Dynamo Compile Examples
+=======================
+
+This document contains examples of using the `torch_tensorrt.dynamo.compile` API, which integrates with the `torch.compile` functionality.
+
+* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
+* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
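For orientation, a minimal sketch of invoking this path, assuming only the entry points already used by the scripts in this commit (`torch_tensorrt.compile` with `ir="torch_compile"`, or `torch.compile` with the "torch_tensorrt" backend):

import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Torch-TensorRT frontend...
trt_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
trt_model(*inputs)

# ...or, equivalently, the torch.compile frontend
trt_model = torch.compile(model, backend="torch_tensorrt")
trt_model(*inputs)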
examples/dynamo/torch_compile_advanced_usage.py

Lines changed: 103 additions & 0 deletions
"""
.. _torch_compile_advanced_usage:

Torch Compile Advanced Usage
======================================================

This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt

# %%

# We begin by defining a model
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_out = self.relu(x)
        y_out = self.relu(y)
        x_y_out = x_out + y_out
        return torch.mean(x_y_out)


# %%
# Compilation with `torch.compile` Using Default Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Define sample float inputs and initialize model
sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# %%

# Next, we compile the model using torch.compile
# For the default settings, we can simply call torch.compile
# with the backend "torch_tensorrt", and run the model on an
# input to cause compilation, as so:
optimized_model = torch.compile(model, backend="torch_tensorrt")
optimized_model(*sample_inputs)

# %%
# Compilation with `torch.compile` Using Custom Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# First, we use Torch utilities to clean up the workspace
# after the previous compile invocation
torch._dynamo.reset()

# Define sample half inputs and initialize model
sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
    torch.rand((5, 7)).half().cuda(),
]
model_half = Model().eval().cuda()

# %%

# If we want to customize certain options in the backend,
# but still use the torch.compile call directly, we can provide
# custom options to the backend via the "options" keyword
# which takes in a dictionary mapping options to values.
#
# For accepted backend options, see the CompilationSettings dataclass:
# py/torch_tensorrt/dynamo/_settings.py
backend_kwargs = {
    "enabled_precisions": {torch.half},
    "debug": True,
    "min_block_size": 2,
    "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
    "optimization_level": 4,
    "use_python_runtime": False,
}

# Run the model on an input to cause compilation, as so:
optimized_model_custom = torch.compile(
    model_half, backend="torch_tensorrt", options=backend_kwargs
)
optimized_model_custom(*sample_inputs_half)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

# %%
# Cuda Driver Error Note
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`,
# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052
# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in::
#
#     if __name__ == '__main__':
#         compile_engine_and_infer()
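The note above names a `compile_engine_and_infer` function without showing its body; a minimal sketch of that scoped-call pattern, reusing the `Model` class defined in this script (the function body is an assumption based on the note, not part of the commit):

def compile_engine_and_infer() -> None:
    # Keeping compilation and inference inside a function scope ensures
    # TensorRT objects are released before the Python interpreter exits
    model = Model().eval().cuda()
    sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
    optimized = torch.compile(model, backend="torch_tensorrt")
    optimized(*sample_inputs)

if __name__ == "__main__":
    compile_engine_and_infer()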
examples/dynamo/torch_compile_resnet_example.py

Lines changed: 93 additions & 0 deletions
"""
.. _torch_compile_resnet:

Compiling ResNet using the Torch-TensorRT `torch.compile` Backend
==================================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt
import torchvision.models as models

# %%

# Initialize model with half precision and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]

# %%
# Optional Input Arguments to `torch_tensorrt.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.half}

# Whether to print verbose logs
debug = True

# Workspace size for TensorRT
workspace_size = 20 << 30

# Minimum number of operators per TRT Engine block
# (Lower value allows more graph segmentation)
min_block_size = 7

# Operations to Run in Torch, regardless of converter support
torch_executed_ops = {}

# %%
# Compilation with `torch_tensorrt.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Build and compile the model with torch.compile, using the Torch-TensorRT backend
optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)

# %%
# Equivalently, we could have run the above via the torch.compile frontend, as so:
# `optimized_model = torch.compile(model, backend="torch_tensorrt", options={"enabled_precisions": enabled_precisions, ...}); optimized_model(*inputs)`

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)

# %%

# Does cause recompilation (new batch size)
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

# %%
# Cuda Driver Error Note
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`,
# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052
# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in::
#
#     if __name__ == '__main__':
#         compile_engine_and_infer()
