
Commit 7fc742f

fix: Add full support for handling long/double
- Add utilities for repairing long/double inputs to TRT engines, including support for autocasting back to long/double after the computation completes
- Add multiple helper functions to enable easy testing and diagnosis of long/double IO to TRT engines
- Add necessary compiler code to enable usage of the `truncate_long_and_double` argument as a switch for the feature
1 parent 9baca3e commit 7fc742f
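
The `truncate_long_and_double` argument described above is a user-facing switch. A minimal usage sketch follows; it assumes the Dynamo backend is registered with `torch.compile` under the name "torch_tensorrt" and that the setting is forwarded through the `options` dictionary, so the exact invocation may differ by release:

import torch

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

model = AddOne().eval().cuda()

# int64 input: with the switch enabled, the backend casts it to int32 before the
# TRT engine runs and casts the result back to int64 afterwards
x = torch.randint(0, 10, (4, 4), dtype=torch.int64, device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",                    # assumed backend registration name
    options={"truncate_long_and_double": True},  # the switch added in this commit
)
print(compiled(x).dtype)  # expected: torch.int64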

2 files changed: 183 additions, 1 deletion

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 36 additions & 0 deletions
@@ -16,6 +16,8 @@
     partition,
     get_submod_inputs,
 )
+
+from torch_tensorrt.dynamo.backend.utils import repair_long_or_double_input
 from torch_tensorrt.dynamo.backend.conversion import convert_module
 
 from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -160,6 +162,40 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        # Ensure all submodule inputs do not require a gradient
+        for param in submodule_inputs:
+            param.requires_grad = False
+
+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            num_submodule_inputs = len(submodule_inputs)
+
+            # For each input to the TRT subgraph, check if its type is long/double
+            for position in range(num_submodule_inputs):
+                param = submodule_inputs[position]
+
+                # If the data type of the input is long/double, insert necessary
+                # casts to replace the operation
+                if param.dtype in (torch.int64, torch.float64):
+                    submodule_outputs = submodule(*submodule_inputs)
+                    repair_long_or_double_input(
+                        partitioned_module,
+                        position,
+                        name,
+                        submodule_outputs,
+                        param.dtype,
+                    )
+
+                    # Repair submodule inputs in accordance with inserted casts
+                    dtype_32bit = (
+                        torch.int32 if (param.dtype == torch.int64) else torch.float32
+                    )
+                    submodule_inputs = (
+                        submodule_inputs[:position]
+                        + (param.to(dtype_32bit),)
+                        + submodule_inputs[position + 1 :]
+                    )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
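
As a standalone illustration of the per-input check above, the following sketch (a hypothetical helper, not part of this commit) reports which sample-input positions would be truncated and to which 32-bit dtype:

import torch

def plan_truncation(sample_inputs):
    # Map positions holding long/double tensors to their 32-bit target dtype,
    # mirroring the dtype selection performed in _compile_module above
    plan = {}
    for position, param in enumerate(sample_inputs):
        if param.dtype in (torch.int64, torch.float64):
            plan[position] = (
                torch.int32 if param.dtype == torch.int64 else torch.float32
            )
    return plan

inputs = (
    torch.randint(0, 5, (2, 2), dtype=torch.int64),  # -> torch.int32
    torch.rand(2, 2, dtype=torch.float64),           # -> torch.float32
    torch.rand(2, 2),                                 # float32, left untouched
)
print(plan_truncation(inputs))  # {0: torch.int32, 1: torch.float32}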

py/torch_tensorrt/dynamo/backend/utils.py

Lines changed: 147 additions & 1 deletion
@@ -1,5 +1,5 @@
 import torch
-
+from torch.fx.node import _get_qualified_name
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
 
@@ -66,3 +66,149 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         )
 
     return device
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (tuple of outputs). This function
+    returns downstream nodes which call the _operator.getitem function, which extracts
+    the element at a particular index in the tuple
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at the specified index in the module node output
+    """
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a "get" node accessing the specified index, store it
+        if _get_qualified_name(user.target) == "_operator.getitem" and (
+            user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def repair_long_or_double_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Union[torch.Tensor, Sequence[torch.Tensor]],
+    dtype: torch.dtype,
+):
+    """Fixes Long/Double type inputs to TRT-accelerated subgraphs
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of TRT-accelerated subgraph module in FX graph
+        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of tensor at position in submodule (double/long)
+    """
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32 and 64 bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
+
+    output_positions_64bit = set()
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the model are 64-bit type and store their indices
+    for output_position, output in enumerate(outputs_list):
+        if output.dtype == dtype_64bit:
+            output_positions_64bit.add(output_position)
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are tuple-type or not
+        is_tuple_output = False
+        if isinstance(submodule_outputs, tuple):
+            is_tuple_output = True
+
+        if not is_tuple_output:
+            # If the output is a single tensor, insert a cast back to the 64-bit type
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream user, append a cast node back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
+    # Clean up graph and ensure invariants are preserved
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
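
The graph surgery in `repair_long_or_double_input` uses standard `torch.fx` rewriting: insert an `aten._to_copy` cast via `inserting_before`/`inserting_after`, then re-wire consumers with `replace_input_with` or `replace_all_uses_with`. The toy sketch below (not the commit's code; the `Double` module is only for illustration) applies the same pattern by inserting a cast in front of an existing node:

import torch
import torch.fx as fx

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

gm = fx.symbolic_trace(Double())

# Locate the multiply node and the input it consumes
mul_node = next(n for n in gm.graph.nodes if n.op == "call_function")
input_node = mul_node.args[0]

# Insert a cast before the multiply, then point the multiply at the cast
with gm.graph.inserting_before(mul_node):
    cast_node = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(input_node,),
        kwargs={"dtype": torch.float32},
    )
mul_node.replace_input_with(input_node, cast_node)

# Same clean-up steps as in repair_long_or_double_input
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

out = gm(torch.ones(3, dtype=torch.float64))
print(out.dtype)  # torch.float32: the cast runs before the multiply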
