
Commit 7f58048

chore: update doc

1 parent 23131c3 commit 7f58048

2 files changed: +113 -0 lines changed

docsrc/index.rst

Lines changed: 2 additions & 0 deletions

@@ -67,6 +67,7 @@ Tutorials
 * :ref:`custom_kernel_plugins`
 * :ref:`mutable_torchtrt_module_example`
 * :ref:`weight_streaming_example`
+* :ref:`pre_allocated_output_example`
 
 .. toctree::
    :caption: Tutorials
@@ -84,6 +85,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/custom_kernel_plugins
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
    tutorials/_rendered_examples/dynamo/weight_streaming_example
+   tutorials/_rendered_examples/dynamo/pre_allocated_output_example
 
 Dynamo Frontend
 ----------------
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
"""
.. _pre_allocated_output_example:

Pre-allocated output buffer
======================================================

The TensorRT runtime module acts as a wrapper around a PyTorch model (or subgraph) that has been compiled and optimized into a TensorRT engine.

When the compiled module is executed, the input and output tensors are bound to the TensorRT execution context for processing.
If output buffer allocation is moved to after the execution of the TensorRT context, and the allocated buffer is used for the next inference, GPU tasks and memory allocation tasks can run concurrently. This overlap allows for more efficient use of GPU resources and can improve inference performance. A conceptual sketch of the two execution flows is shown below.

This optimization is particularly effective in the following cases:

1. Small inference time
   - The allocation of output buffers typically requires minimal CPU cycles, as the caching allocator handles memory reuse efficiently. Because this allocation time is roughly constant relative to the overall inference time, the savings are most noticeable for small inference workloads, where the computation is not large enough to overshadow them.
2. Multiple graph breaks
   - If the module contains operations that are not supported by TensorRT, the unsupported parts are handled by PyTorch, and this fallback results in a graph break. The cumulative effect of optimized buffer allocations across multiple subgraphs can improve overall inference performance.
   - While optimizing output buffers can mitigate some of this overhead, reducing or removing graph breaks should be prioritized, as it enables more comprehensive optimizations.
3. Static input or infrequent input shape changes
   - If the input shape changes, the pre-allocated buffer cannot be reused for the next inference, and a new allocation takes place before the TensorRT context is executed. This feature is therefore not suitable for use cases with frequent input shape changes.
"""
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import timeit

import numpy as np
import torch
import torch_tensorrt
from transformers import BertModel

# %%
# Define function to measure inference performance
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


def test_module_perf(model, *input):
    timings = []

    # Warm-up phase to ensure consistent and accurate performance measurements.
    with torch.no_grad():
        for _ in range(3):
            model(*input)
    torch.cuda.synchronize()

    # Timing phase to measure inference performance
    with torch.no_grad():
        for i in range(10):
            start_time = timeit.default_timer()
            model(*input)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)
    times = np.array(timings)
    time_med = np.median(times)

    # Return the median time as a representative performance metric
    return time_med
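# %%
# Optional: timing with CUDA events
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# A sketch of an alternative measurement, not required by the feature: CUDA
# events time the device-side work between two stream markers, complementing
# the wall-clock measurement above, which also captures host-side allocation.


def test_module_perf_cuda_events(model, *input):
    timings = []
    with torch.no_grad():
        for _ in range(10):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            model(*input)
            end.record()
            torch.cuda.synchronize()
            # ``elapsed_time`` returns milliseconds; convert to seconds to
            # match ``test_module_perf``.
            timings.append(start.elapsed_time(end) / 1000.0)
    return float(np.median(timings))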
# %%
# Load model and compile
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Load bert model
model = (
    BertModel.from_pretrained("bert-base-uncased", torchscript=True)
    .eval()
    .half()
    .to("cuda")
)
# Define sample inputs
inputs = [
    torch.randint(0, 5, (1, 128), dtype=torch.int32).to("cuda"),
    torch.randint(0, 5, (1, 128), dtype=torch.int32).to("cuda"),
]
# Next, we compile the model using torch_tensorrt.compile
optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    enabled_precisions={torch.half},
    inputs=inputs,
)
# %%
# Enable/Disable pre-allocated output buffer feature using runtime API
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# We can enable the pre-allocated output buffer with a context manager
with torch_tensorrt.runtime.enable_pre_allocated_outputs(optimized_model):
    out_trt = optimized_model(*inputs)

# Alternatively, we can enable the feature using a context object
pre_allocated_output_ctx = torch_tensorrt.runtime.enable_pre_allocated_outputs(
    optimized_model
)
pre_allocated_output_ctx.set_pre_allocated_output(True)
time_opt = test_module_perf(optimized_model, *inputs)

# Disable the pre-allocated output buffer feature and perform inference normally
pre_allocated_output_ctx.set_pre_allocated_output(False)
out_trt = optimized_model(*inputs)
time_normal = test_module_perf(optimized_model, *inputs)

time_opt_ms = time_opt * 1000
time_normal_ms = time_normal * 1000

print(f"normal trt model time: {time_normal_ms:.3f} ms")
print(f"pre-allocated output buffer model time: {time_opt_ms:.3f} ms")
