examples: Add example usage scripts for torch_tensorrt.dynamo.compile path #1890

Closed
wants to merge 2 commits into from
3 changes: 2 additions & 1 deletion .gitignore
@@ -32,6 +32,7 @@ docsrc/_build
docsrc/_notebooks
docsrc/_cpp_api
docsrc/_tmp
docsrc/tutorials/_rendered_examples
*.so
__pycache__
*.egg-info
@@ -66,4 +67,4 @@ bazel-tensorrt
*.cache
*cifar-10-batches-py*
bazel-project
build/
build/
7 changes: 7 additions & 0 deletions docsrc/conf.py
@@ -47,6 +47,7 @@
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx_gallery.gen_gallery",
]

napoleon_use_ivar = True
@@ -79,6 +80,12 @@
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# sphinx-gallery configuration
sphinx_gallery_conf = {
    "examples_dirs": "../examples/dynamo",
    "gallery_dirs": "tutorials/_rendered_examples/",
}

# Setup the breathe extension
breathe_projects = {"Torch-TensorRT": "./_tmp/xml"}
breathe_default_project = "Torch-TensorRT"
31 changes: 22 additions & 9 deletions docsrc/index.rst
@@ -36,30 +36,43 @@ Getting Started
getting_started/getting_started_with_windows


Tutorials
User Guide
------------
* :ref:`creating_a_ts_mod`
* :ref:`getting_started_with_fx`
* :ref:`ptq`
* :ref:`runtime`
* :ref:`serving_torch_tensorrt_with_triton`
* :ref:`use_from_pytorch`
* :ref:`using_dla`

.. toctree::
:caption: User Guide
:maxdepth: 1
:hidden:

user_guide/creating_torchscript_module_in_python
user_guide/getting_started_with_fx_path
user_guide/ptq
user_guide/runtime
user_guide/use_from_pytorch
user_guide/using_dla

Tutorials
------------
* :ref:`serving_torch_tensorrt_with_triton`
* :ref:`notebooks`
* :ref:`dynamo_compile`

.. toctree::
:caption: Tutorials
:maxdepth: 1
:maxdepth: 3
:hidden:

tutorials/creating_torchscript_module_in_python
tutorials/getting_started_with_fx_path
tutorials/ptq
tutorials/runtime
tutorials/serving_torch_tensorrt_with_triton
tutorials/use_from_pytorch
tutorials/using_dla
tutorials/notebooks
tutorials/_rendered_examples/dynamo_compile_resnet_example
tutorials/_rendered_examples/dynamo_compile_transformers_example
tutorials/_rendered_examples/dynamo_compile_advanced_usage

Python API Documentation
------------------------
1 change: 1 addition & 0 deletions docsrc/requirements.txt
@@ -1,4 +1,5 @@
sphinx==4.5.0
sphinx-gallery==0.13.0
breathe==4.33.1
exhale==0.3.1
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
12 changes: 12 additions & 0 deletions examples/dynamo/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. _dynamo_compile:

Dynamo Compile Examples
=======================

This document contains usage examples for the ``torch_tensorrt.dynamo.compile`` API, which integrates with ``torch.compile`` functionality.

Overview of Available Scripts
-----------------------------------------------
- `dynamo_compile_resnet_example.py <./dynamo_compile_resnet_example.html>`_: Example showcasing compilation of a ResNet model
- `dynamo_compile_transformers_example.py <./dynamo_compile_transformers_example.html>`_: Example showcasing compilation of a transformer-based model
- `dynamo_compile_advanced_usage.py <./dynamo_compile_advanced_usage.html>`_: Advanced usage, including creating a custom backend to use directly with the ``torch.compile`` API
83 changes: 83 additions & 0 deletions examples/dynamo/dynamo_compile_advanced_usage.py
@@ -0,0 +1,83 @@
"""
Dynamo Compile Advanced Usage
=============================

This interactive script gives an overview of how `torch_tensorrt.dynamo.compile` works and how it integrates with the new `torch.compile` API."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from torch_tensorrt.dynamo.backend import create_backend
from torch_tensorrt.fx.lower_setting import LowerPrecision

# %%

# We begin by defining a model
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_out = self.relu(x)
        y_out = self.relu(y)
        x_y_out = x_out + y_out
        return torch.mean(x_y_out)


# %%
# Compilation with `torch.compile` Using Default Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Define sample float inputs and initialize model
sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# %%

# Next, we compile the model using torch.compile.
# For the default settings, we can simply call torch.compile
# with the backend "tensorrt", then run the model on an
# input to trigger compilation:
optimized_model = torch.compile(model, backend="tensorrt")
optimized_model(*sample_inputs)

# %%
# Compilation with `torch.compile` Using Custom Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Define sample half inputs and initialize model
sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
    torch.rand((5, 7)).half().cuda(),
]
model_half = Model().eval().cuda()

# %%

# If we want to customize certain options in the backend
# but still call torch.compile directly, we can use the
# helper function create_backend to build a custom backend
# that has been pre-populated with the given settings
custom_backend = create_backend(
    lower_precision=LowerPrecision.FP16,
    debug=True,
    min_block_size=2,
    torch_executed_ops=set(),  # empty set; a bare {} would create a dict
)

# Run the model on an input to trigger compilation:
optimized_model_custom = torch.compile(model_half, backend=custom_backend)
optimized_model_custom(*sample_inputs_half)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

with torch.no_grad():
    torch.cuda.empty_cache()
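
The "pre-populated" custom backend in the script above follows a common pattern: `torch.compile` accepts any callable of the form `backend(gm, example_inputs)`, so settings can be bound ahead of time. The sketch below shows that pattern with `functools.partial` and the standard library only. It is not the implementation of `create_backend`; `toy_backend` and its keyword arguments are hypothetical stand-ins chosen for illustration.

```python
from functools import partial
from typing import Any, Dict, Sequence


def toy_backend(
    gm: Any,
    example_inputs: Sequence[Any],
    *,
    debug: bool = False,
    min_block_size: int = 5,
) -> Dict[str, Any]:
    """Hypothetical stand-in for a backend callable.

    A real backend would return a compiled callable; this one only
    reports the settings it received so the binding is visible.
    """
    return {
        "debug": debug,
        "min_block_size": min_block_size,
        "num_inputs": len(example_inputs),
    }


# Bind the settings once; the result still has the (gm, example_inputs)
# call shape that a compile frontend would expect from a backend.
custom_backend = partial(toy_backend, debug=True, min_block_size=2)

result = custom_backend("graph-module-placeholder", [1, 2])
print(result)  # {'debug': True, 'min_block_size': 2, 'num_inputs': 2}
```

Binding settings this way keeps the call site (`torch.compile(model, backend=custom_backend)` in the script above) identical to the default-settings case.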
82 changes: 82 additions & 0 deletions examples/dynamo/dynamo_compile_resnet_example.py
@@ -0,0 +1,82 @@
"""
Dynamo Compile ResNet Example
=============================

This interactive script demonstrates the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt
import torchvision.models as models

# %%

# Initialize model with half precision and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]

# %%
# Optional Input Arguments to `torch_tensorrt.dynamo.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.half}

# Whether to print verbose logs
debug = True

# Workspace size for TensorRT, in bytes (20 << 30 is 20 GiB)
workspace_size = 20 << 30

# Minimum number of operators per TensorRT engine block
# (a lower value allows more graph segmentation)
min_block_size = 3

# Operations to run in Torch, regardless of converter support
# (an empty set; note that a bare {} would create a dict)
torch_executed_ops = set()

# %%
# Compilation with `torch_tensorrt.dynamo.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Build and compile the model with torch.compile, using Torch-TensorRT backend
optimized_model = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)

# %%
# Equivalently, we could have run the above via the convenience frontend:
# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)`

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)

# %%

# Does cause recompilation (new batch size)
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

with torch.no_grad():
    torch.cuda.empty_cache()
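
To make the effect of `min_block_size` in the script above more concrete, here is a toy, standard-library-only sketch of block-size-based partitioning. It is a simplified illustration, not Torch-TensorRT's actual partitioner: the `partition` function, the operator names, and the `supported` set are all hypothetical. Runs of consecutive supported operators become a "tensorrt" segment only if they contain at least `min_block_size` operators; everything else stays in "torch".

```python
from itertools import groupby
from typing import List, Set, Tuple

Segment = Tuple[str, List[str]]


def partition(ops: List[str], supported: Set[str], min_block_size: int) -> List[Segment]:
    """Toy partitioner: assign each run of ops to 'tensorrt' or 'torch'."""
    segments: List[Segment] = []
    # groupby yields maximal runs of consecutive ops sharing the same support status
    for is_supported, run_iter in groupby(ops, key=lambda op: op in supported):
        run = list(run_iter)
        # A supported run is accelerated only if it is long enough
        target = "tensorrt" if is_supported and len(run) >= min_block_size else "torch"
        if segments and segments[-1][0] == target:
            segments[-1][1].extend(run)  # merge with the adjacent same-target segment
        else:
            segments.append((target, run))
    return segments


ops = ["conv", "relu", "conv", "custom_op", "relu", "add"]
supported = {"conv", "relu", "add"}

print(partition(ops, supported, min_block_size=3))
# [('tensorrt', ['conv', 'relu', 'conv']), ('torch', ['custom_op', 'relu', 'add'])]

print(partition(ops, supported, min_block_size=2))
# [('tensorrt', ['conv', 'relu', 'conv']), ('torch', ['custom_op']), ('tensorrt', ['relu', 'add'])]
```

In this toy model, lowering `min_block_size` from 3 to 2 turns the trailing two-operator run into its own accelerated segment, which is the sense in which a lower value allows more graph segmentation.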
92 changes: 92 additions & 0 deletions examples/dynamo/dynamo_compile_transformers_example.py
@@ -0,0 +1,92 @@
"""
Dynamo Compile Transformers Example
===================================

This interactive script demonstrates the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt
from transformers import BertModel

# %%

# Initialize model with float precision and sample inputs
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]


# %%
# Optional Input Arguments to `torch_tensorrt.dynamo.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}

# Whether to print verbose logs
debug = True

# Workspace size for TensorRT, in bytes (20 << 30 is 20 GiB)
workspace_size = 20 << 30

# Minimum number of operators per TensorRT engine block
# (a lower value allows more graph segmentation)
min_block_size = 3

# Operations to run in Torch, regardless of converter support
# (an empty set; note that a bare {} would create a dict)
torch_executed_ops = set()

# %%
# Compilation with `torch_tensorrt.dynamo.compile`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)

# %%
# Equivalently, we could have run the above via the convenience frontend:
# `torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)`

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)

# %%

# Does cause recompilation (new batch size)
new_inputs = [
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

with torch.no_grad():
    torch.cuda.empty_cache()
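
The inference cells above note that reusing the original batch size does not trigger recompilation, while a new batch size does. A rough mental model for this is a cache keyed by input shapes, sketched below with the standard library only. This is an assumption-laden toy: the real guard and recompilation logic in `torch.compile` is more elaborate, and `ShapeKeyedCache` and `fake_compile` are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]
Signature = Tuple[Shape, ...]


class ShapeKeyedCache:
    """Toy model of shape-specialized compilation: one artifact per shape signature."""

    def __init__(self, compile_fn: Callable[[Signature], str]) -> None:
        self._compile_fn = compile_fn
        self._cache: Dict[Signature, str] = {}
        self.compile_count = 0

    def __call__(self, *shapes: Shape) -> str:
        # Compile only when this exact combination of input shapes is new
        if shapes not in self._cache:
            self.compile_count += 1
            self._cache[shapes] = self._compile_fn(shapes)
        return self._cache[shapes]


def fake_compile(shapes: Signature) -> str:
    """Hypothetical stand-in for an expensive engine build."""
    return f"engine{shapes}"


runner = ShapeKeyedCache(fake_compile)
runner((1, 14), (1, 14))  # first call: compiles
runner((1, 14), (1, 14))  # same shapes: cache hit, no recompilation
runner((4, 14), (4, 14))  # new batch size: compiles again
print(runner.compile_count)  # 2
```

Under this model, the second `(1, 14)` call reuses the cached artifact, and only the `(4, 14)` call pays the compilation cost again.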