Skip to content

Commit 28fd11a

Browse files
authored
Merge pull request #1966 from pytorch/dynamo_torch_compile_examples
examples: Add example usage scripts for `torch_tensorrt.dynamo.compile` path [1.1 / x]
2 parents ce06f6e + 968aca4 commit 28fd11a

18 files changed

+499
-10
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ docsrc/_build
3232
docsrc/_notebooks
3333
docsrc/_cpp_api
3434
docsrc/_tmp
35+
docsrc/tutorials/_rendered_examples
3536
*.so
3637
__pycache__
3738
*.egg-info
@@ -67,4 +68,4 @@ bazel-tensorrt
6768
*cifar-10-batches-py*
6869
bazel-project
6970
build/
70-
wheelhouse/
71+
wheelhouse/

docsrc/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ endif
3535
rm -rf $(SOURCEDIR)/_py_api
3636
rm -rf $(SOURCEDIR)/_build
3737
rm -rf $(SOURCEDIR)/_tmp
38+
rm -rf $(SOURCEDIR)/tutorials/_rendered_examples
3839

3940
html:
4041
# mkdir -p $(SOURCEDIR)/_notebooks

docsrc/_static/css/custom.css

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/* sphinx-design styles for cards/tabs
2+
*/
3+
4+
.sphx-glr-thumbcontainer {
5+
padding: 50%;
6+
display: flex;
7+
align-content: center;
8+
}

docsrc/_static/css/pytorch_theme.css

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
body {
2+
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
3+
}
4+
5+
/* Default header fonts are ugly */
6+
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
7+
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
8+
}
9+
10+
/* Use white for docs background */
11+
.wy-side-nav-search {
12+
background-color: #fff;
13+
}
14+
15+
.wy-nav-content-wrap, .wy-menu li.current > a {
16+
background-color: #fff;
17+
}
18+
19+
@media screen and (min-width: 1400px) {
20+
.wy-nav-content-wrap {
21+
background-color: rgba(0, 0, 0, 0.0470588);
22+
}
23+
24+
.wy-nav-content {
25+
background-color: #fff;
26+
}
27+
}
28+
29+
/* Fixes for mobile */
30+
.wy-nav-top {
31+
background-color: #fff;
32+
background-image: url('../img/pytorch-logo-dark.svg');
33+
background-repeat: no-repeat;
34+
background-position: center;
35+
padding: 0;
36+
margin: 0.4045em 0.809em;
37+
color: #333;
38+
}
39+
40+
.wy-nav-top > a {
41+
display: none;
42+
}
43+
44+
@media screen and (max-width: 768px) {
45+
.wy-side-nav-search>a img.logo {
46+
height: 60px;
47+
}
48+
}
49+
50+
/* This is needed to ensure that logo above search scales properly */
51+
.wy-side-nav-search a {
52+
display: block;
53+
}
54+
55+
/* This ensures that multiple constructors will remain in separate lines. */
56+
.rst-content dl:not(.docutils) dt {
57+
display: table;
58+
}
59+
60+
/* Use our red for literals (it's very similar to the original color) */
61+
.rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal {
62+
color: #F05732;
63+
}
64+
65+
.rst-content tt.xref, a .rst-content tt, .rst-content tt.xref,
66+
.rst-content code.xref, a .rst-content tt, a .rst-content code {
67+
color: #404040;
68+
}
69+
70+
/* Change link colors (except for the menu) */
71+
72+
a {
73+
color: #F05732;
74+
}
75+
76+
a:hover {
77+
color: #F05732;
78+
}
79+
80+
81+
a:visited {
82+
color: #D44D2C;
83+
}
84+
85+
.wy-menu a {
86+
color: #b3b3b3;
87+
}
88+
89+
.wy-menu a:hover {
90+
color: #b3b3b3;
91+
}
92+
93+
a.icon.icon-home {
94+
color: #D44D2C;
95+
}
96+
97+
.version{
98+
color: #D44D2C !important;
99+
}
100+
101+
/* Default footer text is quite big */
102+
footer {
103+
font-size: 80%;
104+
}
105+
106+
footer .rst-footer-buttons {
107+
font-size: 125%; /* revert footer settings - 1/80% = 125% */
108+
}
109+
110+
footer p {
111+
font-size: 100%;
112+
}
113+
114+
/* For hidden headers that appear in TOC tree */
115+
/* see https://stackoverflow.com/a/32363545/3343043 */
116+
.rst-content .hidden-section {
117+
display: none;
118+
}
119+
120+
nav .hidden-section {
121+
display: inherit;
122+
}
123+
124+
/* Make code blocks have a background */
125+
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
126+
background: rgba(0, 0, 0, 0.0470588);
127+
}

docsrc/conf.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import torch
1919
import pytorch_sphinx_theme
2020
import torch_tensorrt
21+
from docutils.parsers.rst import Directive, directives
22+
from docutils.statemachine import StringList
23+
from docutils import nodes
2124

2225
# -- Project information -----------------------------------------------------
2326

@@ -47,6 +50,7 @@
4750
"sphinx.ext.coverage",
4851
"sphinx.ext.mathjax",
4952
"sphinx.ext.viewcode",
53+
"sphinx_gallery.gen_gallery",
5054
]
5155

5256
napoleon_use_ivar = True
@@ -78,6 +82,18 @@
7882
# relative to this directory. They are copied after the builtin static files,
7983
# so a file named "default.css" will overwrite the builtin "default.css".
8084
html_static_path = ["_static"]
85+
# Custom CSS paths should either relative to html_static_path
86+
# or fully qualified paths (eg. https://...)
87+
html_css_files = [
88+
"https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css",
89+
"css/custom.css",
90+
]
91+
92+
# sphinx-gallery configuration
93+
sphinx_gallery_conf = {
94+
"examples_dirs": "../examples",
95+
"gallery_dirs": "tutorials/_rendered_examples/",
96+
}
8197

8298
# Setup the breathe extension
8399
breathe_projects = {"Torch-TensorRT": "./_tmp/xml"}

docsrc/index.rst

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,43 @@ Getting Started
3636
getting_started/getting_started_with_windows
3737

3838

39-
Tutorials
39+
User Guide
4040
------------
4141
* :ref:`creating_a_ts_mod`
4242
* :ref:`getting_started_with_fx`
4343
* :ref:`ptq`
4444
* :ref:`runtime`
45-
* :ref:`serving_torch_tensorrt_with_triton`
4645
* :ref:`use_from_pytorch`
4746
* :ref:`using_dla`
47+
48+
.. toctree::
49+
:caption: User Guide
50+
:maxdepth: 1
51+
:hidden:
52+
53+
user_guide/creating_torchscript_module_in_python
54+
user_guide/getting_started_with_fx_path
55+
user_guide/ptq
56+
user_guide/runtime
57+
user_guide/use_from_pytorch
58+
user_guide/using_dla
59+
60+
Tutorials
61+
------------
62+
* :ref:`torch_tensorrt_tutorials`
63+
* :ref:`serving_torch_tensorrt_with_triton`
4864
* :ref:`notebooks`
4965

5066
.. toctree::
5167
:caption: Tutorials
52-
:maxdepth: 1
68+
:maxdepth: 3
5369
:hidden:
5470

55-
tutorials/creating_torchscript_module_in_python
56-
tutorials/getting_started_with_fx_path
57-
tutorials/ptq
58-
tutorials/runtime
5971
tutorials/serving_torch_tensorrt_with_triton
60-
tutorials/use_from_pytorch
61-
tutorials/using_dla
6272
tutorials/notebooks
73+
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
74+
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
75+
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
6376

6477
Python API Documenation
6578
------------------------

docsrc/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
sphinx==4.5.0
2+
sphinx-gallery==0.13.0
23
breathe==4.33.1
34
exhale==0.3.1
45
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
File renamed without changes.
File renamed without changes.
File renamed without changes.

examples/README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
.. _torch_tensorrt_tutorials:
2+
3+
Torch-TensorRT Tutorials
4+
===========================
5+
6+
The user guide covers the basic concepts and usage of Torch-TensorRT.
7+
We also provide a number of tutorials to explore specific usecases and advanced concepts

examples/dynamo/README.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
.. _torch_compile:
2+
3+
Dynamo / ``torch.compile``
4+
----------------------------
5+
6+
Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
7+
a number of ways you can leverage this backend to accelerate inference.
8+
9+
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
10+
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
11+
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
.. _torch_compile_advanced_usage:
3+
4+
Torch Compile Advanced Usage
5+
======================================================
6+
7+
This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""
8+
9+
# %%
10+
# Imports and Model Definition
11+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
12+
13+
import torch
14+
import torch_tensorrt
15+
16+
# %%
17+
18+
# We begin by defining a model
19+
class Model(torch.nn.Module):
20+
def __init__(self) -> None:
21+
super().__init__()
22+
self.relu = torch.nn.ReLU()
23+
24+
def forward(self, x: torch.Tensor, y: torch.Tensor):
25+
x_out = self.relu(x)
26+
y_out = self.relu(y)
27+
x_y_out = x_out + y_out
28+
return torch.mean(x_y_out)
29+
30+
31+
# %%
32+
# Compilation with `torch.compile` Using Default Settings
33+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
34+
35+
# Define sample float inputs and initialize model
36+
sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
37+
model = Model().eval().cuda()
38+
39+
# %%
40+
41+
# Next, we compile the model using torch.compile
42+
# For the default settings, we can simply call torch.compile
43+
# with the backend "torch_tensorrt", and run the model on an
44+
# input to cause compilation, as so:
45+
optimized_model = torch.compile(model, backend="torch_tensorrt")
46+
optimized_model(*sample_inputs)
47+
48+
# %%
49+
# Compilation with `torch.compile` Using Custom Settings
50+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
51+
52+
# First, we use Torch utilities to clean up the workspace
53+
# after the previous compile invocation
54+
torch._dynamo.reset()
55+
56+
# Define sample half inputs and initialize model
57+
sample_inputs_half = [
58+
torch.rand((5, 7)).half().cuda(),
59+
torch.rand((5, 7)).half().cuda(),
60+
]
61+
model_half = Model().eval().cuda()
62+
63+
# %%
64+
65+
# If we want to customize certain options in the backend,
66+
# but still use the torch.compile call directly, we can provide
67+
# custom options to the backend via the "options" keyword
68+
# which takes in a dictionary mapping options to values.
69+
#
70+
# For accepted backend options, see the CompilationSettings dataclass:
71+
# py/torch_tensorrt/dynamo/_settings.py
72+
backend_kwargs = {
73+
"enabled_precisions": {torch.half},
74+
"debug": True,
75+
"min_block_size": 2,
76+
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
77+
"optimization_level": 4,
78+
"use_python_runtime": False,
79+
}
80+
81+
# Run the model on an input to cause compilation, as so:
82+
optimized_model_custom = torch.compile(
83+
model_half, backend="torch_tensorrt", options=backend_kwargs
84+
)
85+
optimized_model_custom(*sample_inputs_half)
86+
87+
# %%
88+
# Cleanup
89+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
90+
91+
# Finally, we use Torch utilities to clean up the workspace
92+
torch._dynamo.reset()
93+
94+
# %%
95+
# Cuda Driver Error Note
96+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
97+
#
98+
# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`,
99+
# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052
100+
# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in::
101+
#
102+
# if __name__ == '__main__':
103+
# compile_engine_and_infer()

0 commit comments

Comments
 (0)