
Commit 7cf4ad9

Fixed the comment and added engine cache example
1 parent: 1a309b8

1 file changed (+71 −5)

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 71 additions & 5 deletions
@@ -11,12 +11,13 @@
 The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
 
 In this tutorial, we are going to walk through
-1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
-2. Save a Mutable Torch TensorRT Module
-3. Integration with Huggingface pipeline in LoRA use case
-4. Usage of dynamic shape with Mutable Torch TensorRT Module
+1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
+2. Save a Mutable Torch TensorRT Module
+3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """
 
+# %%
 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
@@ -144,6 +145,12 @@
 # %%
 # Use Mutable Torch TensorRT module with dynamic shape
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of the arg_inputs and kwarg_inputs passed to the forward function
+# and should not omit any entries (except None in the kwarg_inputs). If an input contains a nested dict/list, the dynamic shape hint for that entry should also be a nested dict/list.
+# If dynamic shape is not required for an input, an empty dictionary should be given as the hint for that input.
+# Note that you should exclude keyword arguments with value None, as those will be filtered out.
+
+
 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -167,7 +174,10 @@ def forward(self, a, b, c={}):
 dim_2 = torch.export.Dim("dim2", min=1, max=50)
 args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
 kwarg_dynamic_shapes = {
-    "c": {"a": {}, "b": {0: dim_2}},
+    "c": {
+        "a": {},
+        "b": {0: dim_2},
+    },  # a's shape does not change so we give it an empty dict
 }
 # Export the model first with custom dynamic shape constraints
 model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
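
For readers of this diff who do not have the rest of the example file open, here is a minimal sketch of how the shape hints above mirror the structure of the inputs to forward(self, a, b, c={}): one hint per positional argument, keyword hints that repeat the nesting of kwarg_inputs, and an empty dict for any entry whose shape stays fixed. The set_expected_dynamic_shape_range call at the end is assumed from the surrounding example file and is not shown in this diff.

# Sketch only: shape hints mirroring the forward(self, a, b, c={}) signature used above.
import torch
import torch_tensorrt as torch_trt  # assumed available, as in the example file

dim_0 = torch.export.Dim("dim0", min=1, max=50)
dim_1 = torch.export.Dim("dim1", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)

# One hint per positional input, in the same order as arg_inputs; keys are dimension indices.
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})

# Keyword hints repeat the nesting of kwarg_inputs; {} marks a static entry,
# and keyword arguments passed as None must not appear here at all.
kwarg_dynamic_shapes = {
    "c": {
        "a": {},          # shape of c["a"] never changes
        "b": {0: dim_2},  # dim 0 of c["b"] is dynamic
    },
}

# The example file registers the hints on the mutable module roughly like this
# (method name assumed, not part of this hunk):
# mutable_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
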
@@ -181,3 +191,59 @@ def forward(self, a, b, c={}):
 }
 # Run without recompiling
 model(*inputs_2, **kwargs_2)
+
+# %%
+# Use Mutable Torch TensorRT module with persistent cache
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# By leveraging engine caching, we can skip rebuilding engines that were already compiled and save a significant amount of time.
+import os
+
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+
+model = models.resnet18(pretrained=True).eval().to("cuda")
+enabled_precisions = {torch.float}
+debug = False
+min_block_size = 1
+use_python_runtime = True
+
+times = []
+start = torch.cuda.Event(enable_timing=True)
+end = torch.cuda.Event(enable_timing=True)
+
+
+example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+# The batch dimension (dim0) of the inputs varies across the runs below
+model = torch_trt.MutableTorchTensorRTModule(
+    model,
+    use_python_runtime=use_python_runtime,
+    enabled_precisions=enabled_precisions,
+    debug=debug,
+    min_block_size=min_block_size,
+    immutable_weights=False,
+    cache_built_engines=True,
+    reuse_cached_engines=True,
+    engine_cache_size=1 << 30,  # 1GB
+)
+
+
+def remove_timing_cache(path=TIMING_CACHE_PATH):
+    if os.path.exists(path):
+        os.remove(path)
+
+
+remove_timing_cache()
+
+for i in range(4):
+    inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
+
+    start.record()
+    model(*inputs)  # Recompile
+    end.record()
+    torch.cuda.synchronize()
+    times.append(start.elapsed_time(end))
+
+print("----------------dynamo_compile----------------")
+print("Without engine caching, used:", times[0], "ms")
+print("With engine caching, used:", times[1], "ms")
+print("With engine caching, used:", times[2], "ms")
+print("With engine caching, used:", times[3], "ms")
