
Commit cee6d24

docs: A tutorial on how to overload converters in Torch-TensorRT (#3197)
Signed-off-by: Naren Dasan <[email protected]>
1 parent 15e349d commit cee6d24

File tree

3 files changed: +213 -0 lines changed


docsrc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
    tutorials/_rendered_examples/dynamo/torch_export_cudagraphs
+   tutorials/_rendered_examples/dynamo/converter_overloading
    tutorials/_rendered_examples/dynamo/custom_kernel_plugins
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
+* :ref:`converter_overloading`: How to write custom converters and overload existing ones
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
 * :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
examples/dynamo/converter_overloading.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
"""
.. _converter_overloading:

Overloading Torch-TensorRT Converters with Custom Converters
===================================================================

If for some reason you want to change the conversion behavior of a specific PyTorch operation to TensorRT, you can do so by writing a custom converter and overloading Torch-TensorRT's existing one.
This may be because you want to use a custom kernel instead of TensorRT's kernels, or because you want to use a different implementation of a layer in TensorRT than the one
Torch-TensorRT would normally use.

In this tutorial, we will demonstrate how to overload Torch-TensorRT's conversion of the `torch.nn.functional.gelu` operation to TensorRT with a custom converter that uses a different implementation
of the GeLU layer.

"""

import logging
import sys

import torch
import torch_tensorrt

# %%
# GeLU Operator in PyTorch
# ------------------------
#
# GeLU has 2 modes in PyTorch, one using the ``erf`` function and the other using the ``tanh`` approximation.
# TensorRT natively supports both implementations as an activation layer, but suppose we want to use a custom implementation of GeLU in TensorRT only for ``tanh`` mode.
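
# %%
# For intuition, here is a quick illustrative comparison of the two modes (a small sketch, not
# required for the rest of the tutorial): the ``tanh`` mode computes the approximation
# ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`` rather than the exact ``erf`` form,
# so the outputs differ slightly.

_x = torch.randn(4)
print(torch.nn.functional.gelu(_x, approximate="tanh"))
print(torch.nn.functional.gelu(_x, approximate="none"))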


class GeLU(torch.nn.Module):
    def __init__(self, mode="tanh"):
        super().__init__()
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate=self.mode)


my_mod = GeLU(mode="tanh")
ex_input = torch.randn(2, 5).to("cuda")


# %%
# As a baseline, we can use the standard Torch-TensorRT GeLU converter (in tanh approximation mode) with our module.
my_standard_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)
print(my_standard_gelu.graph)
print(my_standard_gelu(ex_input))

# %%
# Writing a Custom Converter
# --------------------------
#
# Converters are functions that take a specific instance of a PyTorch operation in a PyTorch graph and convert it to an equivalent set of TensorRT operations in an under-construction TensorRT graph.
# They are registered with Torch-TensorRT using the ``@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter`` decorator.
# At a code level, a converter takes the current conversion state (``ConversionCtx``), the next operator in the graph to convert, and the arguments to that node,
# and returns the placeholder outputs for that operation, while inserting the necessary TensorRT layers into the TensorRT network as a side effect.
#

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import ConversionContext

import tensorrt as trt
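

# %%
# To make the converter contract described above concrete, the skeleton below sketches the
# expected signature (``my_converter_skeleton`` is an illustrative, hypothetical name; the real
# converter for this tutorial is registered and implemented in the following sections):


def my_converter_skeleton(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # A real converter would add TensorRT layers to the network under construction here
    # and return the resulting output tensor(s).
    raise NotImplementedError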


# %%
# Converter Metadata
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^


@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    # The PyTorch operation to convert; when this operation is encountered, this converter will be called
    torch.ops.aten.gelu.default,
    # Validators are functions that determine whether a given node can be converted by this converter
    capability_validator=lambda node, settings: (
        "approximate" in node.kwargs and node.kwargs["approximate"] == "tanh"
    ),
    # Can this converter be used in cases where the input shapes are dynamic?
    supports_dynamic_shapes=True,
    # Set the priority of the converter to supersede the default one
    priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH,
)

# %%
# For the decorator defining a converter, there is one required argument and a few optional ones.
# All converters need a target operator they will run against, the idea being that when there is an instance of ``torch.ops.aten.gelu.default`` in the graph, this converter will be called.
#
# Following the target operator, you can provide additional metadata that defines the capabilities of the converter and its priority relative to other possible converters for the target in question.
#
# The primary tool for defining the capabilities of a converter is the ``capability_validator`` argument,
# which is a function (a lambda above) that takes a specific node in the graph as well as the user compilation settings and returns a boolean indicating whether the converter can be used for that node.
# This validator function is run prior to the graph partitioning phase against each instance of the converter's target op. Nodes for which no converter's validator passes during this phase will be executed in PyTorch at runtime.
# This is useful for cases where you want to use a custom converter only in specific situations, like in our case where we only want to use our converter when ``approximate == "tanh"``.
#
# Distinct from the validator is the ``supports_dynamic_shapes`` argument, a boolean indicating whether the converter can be used in cases where the input shapes are dynamic.
# If this is set to ``False``, the converter will be disabled when the inputs provided by the user are dynamic. If there are no alternatives that support dynamic shapes, the operation will be run in PyTorch.
#
# Finally, there is the ``priority`` argument, an enum from the ``torch_tensorrt.dynamo.conversion.ConverterPriority`` class that defines the priority of the converter. The two options are ``HIGH`` and ``STANDARD``.
# Converters registered with ``STANDARD`` will be appended to the converter list for a given operation, while converters registered with ``HIGH`` will be prepended to the list.
# Candidate converters are evaluated for their suitability in this priority order, and the first converter that passes the validator is used.
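

# %%
# For reference, the validator does not need to be a lambda; any callable with the same
# ``(node, settings) -> bool`` signature can be used. An illustrative (hypothetical) named
# version of the validator above might look like this:
#
# .. code-block:: python
#
#     def gelu_tanh_validator(node: Node, settings: CompilationSettings) -> bool:
#         # Only claim gelu nodes whose ``approximate`` kwarg is set to "tanh"
#         return node.kwargs.get("approximate", "none") == "tanh"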


# %%
# Converter Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The converter function itself takes the following arguments: the current conversion context, the target operator, the arguments to the target operator, the keyword arguments to the target operator, and the name of the target operator.
# Arguments can be any of Python primitives, ``torch.Tensor``, ``numpy`` arrays, or ``ITensor`` objects.
# The converter function should return the outputs of the target operator, primarily as TensorRT ``ITensor`` objects. These inputs and outputs should correspond to the schema
# of the target PyTorch operator, which can be found here: `https://pytorch.org/docs/main/torch.compiler_ir.html <https://pytorch.org/docs/main/torch.compiler_ir.html>`_.
#
# Since Torch-TensorRT covers the core ATen opset, it has already abstracted many of the common low-level operations into helper functions that can be used to build up the TensorRT network.
# This allows developers to avoid the boilerplate of creating the TensorRT layers directly and instead focus on the high-level logic of the conversion.
# The helper functions are located in the ``torch_tensorrt.dynamo.conversion.impl`` module and are designed to be composable and interoperable with raw-TensorRT implementations.
# In this case, we will use the Torch-TensorRT ``mul``, ``add`` and ``tanh`` functions from ``impl`` to implement our alternative GeLU layer.


def aten_ops_gelu(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # The schema for torch.ops.aten.gelu.default is gelu(Tensor self, *, str approximate='none') -> Tensor

    from torch_tensorrt.dynamo import SourceIR
    from torch_tensorrt.dynamo.conversion import impl

    # Cheap way to allow layer names to be unique
    op_count = 0

    def get_op_count():
        nonlocal op_count
        op_count += 1
        return op_count

    mul = lambda x, y: impl.elementwise.mul(
        ctx,
        target,
        name=f"mul_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    add = lambda x, y: impl.elementwise.add(
        ctx,
        target,
        name=f"add_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    tanh = lambda x: impl.activation.tanh(
        ctx, target, name=f"tanh_{get_op_count()}", source_ir=SourceIR.ATEN, input_val=x
    )

    # So we know that our custom converter is being run instead of the standard one
    print("\n\n---------------------------")
    print("Using custom GeLU converter")
    print("---------------------------\n\n")

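    # Taken together, the elementwise and activation calls below build the tanh approximation of GeLU:
    #   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # where 0.79788456... is sqrt(2/pi) and 0.044715 is the coefficient of the cubic term.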
    x_7 = mul(args[0], 0.5)
    x_8 = mul(args[0], 0.79788456080000003)
    x_9 = mul(args[0], 0.044714999999999998)
    x_10 = mul(x_9, args[0])
    x_11 = add(x_10, 1.0)
    x_12 = mul(x_8, x_11)
    x_13 = tanh(x_12)
    x_14 = add(x_13, 1.0)
    x_15 = mul(x_7, x_14)

    return x_15


# %%
# Using our Custom Converter
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We can now recompile and see that our custom converter is being called to convert GeLU to TensorRT.
my_custom_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)

print(my_custom_gelu.graph)
print(my_custom_gelu(ex_input))

# %%
#
# We can verify that our implementation matches the TensorRT implementation for the ``tanh`` approximation.
print(
    f"tanh approximations are close: {torch.allclose(my_standard_gelu(ex_input), my_custom_gelu(ex_input))}"
)


# %%
#
# Finally, we want to verify that when the ``approximate`` argument is not set to ``tanh``, our custom converter is not used.

my_mod_erf = GeLU(mode="none")
my_gelu_erf = torch_tensorrt.compile(
    my_mod_erf, arg_inputs=(ex_input,), min_block_size=1
)

# %%
#
# Notice that we don't see the print statement from our custom converter, indicating that it was not used. However, looking at the graph, we can still see that a TensorRT engine was created to run the GeLU operation.
# In this case, the validator for our custom converter returned ``False``, so the conversion system moved on to the next converter in the list, the standard GeLU converter, and used that one to convert the operation.

print(my_gelu_erf.graph)
print(my_gelu_erf(ex_input))
