"""
.. _cudagraphs_wrapper_example:

Wrapped runtime module for cuda graphs
======================================

If Torch-TensorRT encounters unsupported operations during compilation, it can fall back to using
PyTorch's native implementation for those specific operations. This fallback mechanism allows the
rest of the model to be executed using TensorRT, while only the unsupported parts are handled by PyTorch.
This fallback results in a graph break, which can reduce the overall performance benefits of using
TensorRT because it introduces additional overhead from switching between TensorRT and PyTorch execution contexts.

Applying CUDA Graphs to a PyTorch module that contains graph breaks can enhance performance by leveraging
the benefits of CUDA Graphs even in the presence of these breaks. Torch-TensorRT provides a wrapper
runtime module with CUDA Graphs for modules that have graph breaks, allowing you to mitigate the
inefficiencies introduced by these breaks.
"""

# %%
# Imports and Model Definition
# ----------------------------------

import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# %%
# Compiler options
# ----------------------------------
#
# The 'torch_executed_ops' compiler option is used to demonstrate graph breaks for this example.
# The debug=True compiler option provides detailed insights into the compilation process and helps
# pinpoint where graph breaks occur.

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    min_block_size=1,
    pass_through_build_failures=True,
    debug=True,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# %%
# Compiler log
# ----------------------------------
#
# This compiler log indicates that the torch.ops.aten.mul.Tensor operator is executed by PyTorch.
# Performance of this module can be enhanced by using the wrapped module.

##############################################################################
# .. code-block:: none
#
#     ++++++++++++++ Dry-Run Results for Graph +++++++++++++++++
#
#     The graph consists of 3 Total Operators, of which 2 operators are supported, 66.67% coverage
#
#     The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:
#     torch.ops.aten.mul.Tensor: 1
#
#     The following nodes are currently set to run in Torch:
#     Node: torch.ops.aten.mul.Tensor, with layer location: /mul
#     Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner

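# %%
# Optional sanity check on the partitioned module
# ------------------------------------------------
#
# The short check below is a minimal sketch and not part of the original example: even with the
# forced graph break, the hybrid TensorRT/PyTorch module should produce the same result as the
# eager PyTorch model. The tolerances are illustrative assumptions and may need loosening for
# larger models or reduced-precision engines.

with torch.no_grad():
    eager_output = model(input)  # eager PyTorch reference
    trt_output = trt_model(input)  # partitioned TensorRT + PyTorch execution

# Illustrative tolerances; adjust for your model and precision
assert torch.allclose(eager_output, trt_output, rtol=1e-3, atol=1e-3), "Outputs diverged"
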
# %%
# Running the wrapped module with CUDA Graphs
# -------------------------------------------
#
# Please note that initializing with the wrapper module involves a warm-up phase where the module
# is executed several times. This ensures that memory allocations and initializations are
# not recorded in CUDA Graphs.
# When using the TensorRT module within a CUDA Graph context manager, a wrapped_module is returned.
# This module captures the execution graph, allowing for efficient replay during subsequent
# inferences by reducing kernel launch overheads and improving performance.
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module:
    wrapped_module(input)
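
# %%
# Optional: measuring the effect of the CUDA Graphs wrapper
# -----------------------------------------------------------
#
# The snippet below is a minimal timing sketch, not part of the original example. It uses
# torch.cuda.Event to compare the average latency of the plain partitioned module against the
# CUDA Graphs wrapped module. The helper name ``measure_avg_latency_ms`` and the iteration counts
# are illustrative choices; absolute numbers depend on the model, GPU, and driver, and for a model
# this small the gap mostly reflects reduced kernel launch overhead.


def measure_avg_latency_ms(module, inp, iterations=100, warmup=3):
    # Warm up outside the timed region so allocations and initializations are excluded
    for _ in range(warmup):
        module(inp)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iterations):
        module(inp)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iterations


baseline_ms = measure_avg_latency_ms(trt_model, input)
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module:
    cudagraphs_ms = measure_avg_latency_ms(wrapped_module, input)

print(f"Average latency without CUDA Graphs wrapper: {baseline_ms:.4f} ms")
print(f"Average latency with CUDA Graphs wrapper: {cudagraphs_ms:.4f} ms")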