Commit 9cbd31b

feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.torch_compile` path, which can compile models with data-dependent control flow and other restrictions that make other compilation methods more difficult (a sketch of such a model follows the file summary below)
- Cover the different customizable features allowed in the new backend
1 parent 25db257 commit 9cbd31b

3 files changed: +147 −0 lines changed
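As context for the commit message's claim about data-dependent control flow, here is a minimal sketch (not part of the commit; the ConditionalModel class and its threshold branch are illustrative) of such a model passed through torch.compile with the "tensorrt" backend used in the first script:

import torch
import torch_tensorrt.dynamo.torch_compile  # assumed to register the "tensorrt" backend, as in the first script


class ConditionalModel(torch.nn.Module):
    """Toy model whose execution path depends on the input values."""

    def forward(self, x: torch.Tensor):
        # Data-dependent branch: torch.compile handles this via a graph break
        # rather than rejecting the whole model.
        if x.sum() > 0:
            return torch.relu(x)
        return torch.tanh(x)


model = ConditionalModel().eval().cuda()
sample_input = torch.randn((5, 7)).cuda()

# Supported subgraphs can be lowered to TensorRT; the remainder, including
# the branch itself, runs in eager PyTorch.
optimized_model = torch.compile(model, backend="tensorrt")
print(optimized_model(sample_input))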
Lines changed: 53 additions & 0 deletions

import torch
from torch_tensorrt.dynamo.torch_compile import create_backend
from torch_tensorrt.fx.lower_setting import LowerPrecision


##### Overview
# This script is intended as an overview of the process by which
# torch_tensorrt.dynamo.torch_compile works, and how it integrates
# with the new torch.compile API.

# We begin by defining a model
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_out = self.relu(x)
        y_out = self.relu(y)
        x_y_out = x_out + y_out
        return torch.mean(x_y_out)


##### Compilation using default settings

sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# Next, we compile the model using torch.compile.
# For the default settings, we can simply call torch.compile
# with the backend "tensorrt", and run the model on an
# input to trigger compilation, like so:
optimized_model = torch.compile(model, backend="tensorrt")
optimized_model(*sample_inputs)


##### Compilation using custom settings

sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
    torch.rand((5, 7)).half().cuda(),
]
model_half = Model().half().eval().cuda()

# Alternatively, if we want to customize certain options in the backend,
# but still use the torch.compile call directly, we can call the
# convenience/helper function create_backend to create a custom backend
# which has been pre-populated with certain keyword arguments.
custom_backend = create_backend(
    lower_precision=LowerPrecision.FP16, debug=True, max_num_trt_engines=2
)
optimized_model_custom = torch.compile(model_half, backend=custom_backend)
optimized_model_custom(*sample_inputs_half)
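Since the default-settings compilation above runs the same Model definition as eager PyTorch, a quick parity check can confirm that compilation preserved numerics. This is an illustrative addition, not part of the committed script; the tolerances are arbitrary and it reuses the model, sample_inputs, and optimized_model names defined above.

# Illustrative parity check against eager execution (not part of the commit).
with torch.no_grad():
    eager_out = model(*sample_inputs)
    compiled_out = optimized_model(*sample_inputs)
assert torch.allclose(eager_out, compiled_out, rtol=1e-3, atol=1e-3), (
    "Compiled output diverged from eager output"
)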
Lines changed: 45 additions & 0 deletions

import torch
from torch_tensorrt.dynamo import torch_compile
import torchvision.models as models

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the resnet18 model

# Initialize model and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.half}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 100


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)


# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)


# Does cause recompilation (new batch size)
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
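To make the recompilation behavior called out in the comments above visible, each call can be timed. This is an illustrative sketch, not part of the commit: it reuses the optimized_model defined above, and the probe batch size of 4 is arbitrary.

import time


def timed_call(compiled_model, example_inputs):
    """Time one forward pass, synchronizing CUDA before and after."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = compiled_model(*example_inputs)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start


# A batch size not seen above (4) should trigger another recompilation on the
# first call; the repeat call at the same size reuses the cached compilation.
probe_inputs = [torch.randn((4, 3, 224, 224)).half().to("cuda")]
_, first_s = timed_call(optimized_model, probe_inputs)
_, repeat_s = timed_call(optimized_model, probe_inputs)
print(f"first call: {first_s:.3f}s, repeat call: {repeat_s:.3f}s")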
Lines changed: 49 additions & 0 deletions

import torch
from torch_tensorrt.dynamo import torch_compile
from transformers import BertModel

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the BERT base uncased model

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
inputs = [
    torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda"),
]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 200


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)

# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)


# Does cause recompilation (new batch size)
new_batch_size_inputs = [
    torch.randint(0, 1, (8, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 1, (8, 14), dtype=torch.int32).to("cuda"),
]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
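The random integer tensors above stand in for tokenized text. An input built from the actual tokenizer could look like the sketch below (illustrative, not part of the commit; it assumes the two positional tensors correspond to input_ids and attention_mask, and reuses the optimized_model defined above).

from transformers import BertTokenizer

# Illustrative: real token-id inputs instead of random integers.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("TensorRT accelerates transformer inference.", return_tensors="pt")
real_inputs = [
    encoded["input_ids"].to(torch.int32).to("cuda"),
    encoded["attention_mask"].to(torch.int32).to("cuda"),
]
# A sequence length different from 14 will trigger another recompilation.
real_outputs = optimized_model(*real_inputs)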
