Commit 26a3059

fix: Add support for truncate_long_and_double

- Add Dynamo compile support for the `truncate_long_and_double` compilation argument by intercepting long/double-type inputs and casting them to their 32-bit counterparts prior to usage in TRT-accelerated subgraphs, then casting back if necessary
- Add robust logic to handle 64-bit inputs and outputs
- Add test cases for long and double scenarios
- Centralize the truncation utility for later use in the Dynamo export path

1 parent 7fc742f · commit 26a3059
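
The centralized utility itself lives in the new truncate_long_and_double module, whose diff is not rendered in this view. As a rough sketch of the mechanism described above, based only on the inline logic this commit removes from backends.py (the `_sketch` suffix is mine, and the real utility also rewrites the enclosing FX graph, as the removed repair_long_or_double_input shown further below did):

from typing import Sequence

import torch


def repair_long_or_double_inputs_sketch(
    parent_module: torch.fx.GraphModule,
    submodule: torch.fx.GraphModule,
    submodule_inputs: Sequence[torch.Tensor],
) -> Sequence[torch.Tensor]:
    """Hedged sketch: downcast 64-bit subgraph inputs to their 32-bit twins.

    Mirrors the inline loop removed from backends.py; the actual
    implementation additionally inserts casts into parent_module's graph so
    downstream Torch ops still observe 64-bit values where required.
    """
    repaired_inputs = list(submodule_inputs)

    for position, param in enumerate(repaired_inputs):
        # TRT cannot ingest long/double tensors, so truncate them to 32-bit
        if param.dtype in (torch.int64, torch.float64):
            dtype_32bit = torch.int32 if param.dtype == torch.int64 else torch.float32
            repaired_inputs[position] = param.to(dtype_32bit)

    return tuple(repaired_inputs)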

File tree

5 files changed: +322, -177 lines


py/torch_tensorrt/dynamo/backend/backends.py (4 additions, 28 deletions)

@@ -17,7 +17,7 @@
     get_submod_inputs,
 )

-from torch_tensorrt.dynamo.backend.utils import repair_long_or_double_input
+from torch_tensorrt.dynamo.common import repair_long_or_double_inputs
 from torch_tensorrt.dynamo.backend.conversion import convert_module

 from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -168,33 +168,9 @@ def _compile_module(

         # Handle long/double inputs if requested by the user
         if settings.truncate_long_and_double:
-            num_submodule_inputs = len(submodule_inputs)
-
-            # For each input to the TRT subgraph, check if its type is long/double
-            for position in range(num_submodule_inputs):
-                param = submodule_inputs[position]
-
-                # If the data type of the input is long/double, insert necessary
-                # casts to replace the operation
-                if param.dtype in (torch.int64, torch.float64):
-                    submodule_outputs = submodule(*submodule_inputs)
-                    repair_long_or_double_input(
-                        partitioned_module,
-                        position,
-                        name,
-                        submodule_outputs,
-                        param.dtype,
-                    )
-
-                    # Repair submodule inputs in accordance with inserted casts
-                    dtype_32bit = (
-                        torch.int32 if (param.dtype == torch.int64) else torch.float32
-                    )
-                    submodule_inputs = (
-                        submodule_inputs[:position]
-                        + (param.to(dtype_32bit),)
-                        + submodule_inputs[position + 1 :]
-                    )
+            submodule_inputs = repair_long_or_double_inputs(
+                partitioned_module, submodule, submodule_inputs
+            )

         # Create TRT Module from submodule
         trt_mod = convert_module(

py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py (112 additions, 1 deletion)

@@ -4,7 +4,7 @@
 from copy import deepcopy
 from torch_tensorrt.dynamo import compile
 from utils import lower_graph_testing
-from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
+from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT


 class TestTRTModuleNextCompilation(TestCase):
@@ -169,5 +169,116 @@ def forward(self, x, y):
         )


+class Test64BitInput(TestCase):
+    def test_float64_input_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.mean.dim(
+                    torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
+                )
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+    def test_int64_input_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.div.Tensor_mode(
+                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
+                )
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
 if __name__ == "__main__":
     run_tests()
py/torch_tensorrt/dynamo/backend/utils.py (0 additions, 147 deletions)

@@ -1,5 +1,4 @@
 import torch
-from torch.fx.node import _get_qualified_name
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device

@@ -66,149 +65,3 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         )

     return device
-
-
-def _extract_downstream_get_nodes(
-    module_node: torch.fx.Node, output_indices: Sequence[int]
-) -> Sequence[torch.fx.Node]:
-    """Extracts downstream users of a node which get the item at a particular index
-
-    Certain module-type nodes have multiple outputs (tuple of outputs). This function
-    returns downstream nodes which call the _operator.getitem function, which extracts
-    the element at a particular index in the tuple
-
-    Args:
-        module_node: FX module-type node to analyze
-        output_indices: Indices in the module node output to search for
-    Returns:
-        List of nodes which get the item at the specified index in the module node output
-    """
-    get_nodes = []
-
-    # Iterate over all downstream users of the node object
-    for user in module_node.users:
-        # If the user is a "get" node accessing the specified index, store it
-        if _get_qualified_name(user.target) == "_operator.getitem" and (
-            user.args[1] in output_indices
-        ):
-            get_nodes.append(user)
-
-    return get_nodes
-
-
-def repair_long_or_double_input(
-    gm: torch.fx.GraphModule,
-    position: int,
-    submodule_name: str,
-    submodule_outputs: Union[torch.Tensor, Sequence[torch.Tensor]],
-    dtype: torch.dtype,
-):
-    """Fixes Long/Double type inputs to TRT-accelerated subgraphs
-
-    In-place modifies the provided graph
-
-    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
-    inserts an upcast back to the 64-bit type for subsequent Torch operations
-
-    Args:
-        gm: FX GraphModule enclosing the TRT subgraph
-        position: Index in the submodule inputs at which the long or double input is found
-        submodule_name: Name of TRT-accelerated subgraph module in FX graph
-        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
-        dtype: Data type of tensor at position in submodule (double/long)
-    """
-    assert dtype in (
-        torch.int64,
-        torch.float64,
-    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
-
-    # Determine target data type in 32 and 64 bit forms
-    dtype_64bit = dtype
-    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
-
-    # Find the node representing the submodule in the graph
-    module_node = None
-
-    # Iterate over all nodes in the graph, seeking target module name match
-    for n in gm.graph.nodes:
-        if n.op == "call_module" and str(n.target) == submodule_name:
-            module_node = n
-            break
-
-    if module_node is None:
-        raise AssertionError(
-            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
-        )
-
-    # Extract the 64-bit node of the input
-    node_64bit = module_node.all_input_nodes[position]
-
-    # Prior to the module, insert a cast to the 32-bit equivalent node
-    with gm.graph.inserting_before(module_node):
-        node_32bit = gm.graph.call_function(
-            torch.ops.aten._to_copy.default,
-            args=(node_64bit,),
-            kwargs={"dtype": dtype_32bit},
-        )
-
-    # Replace 64-bit input to TRT module with new 32-bit cast node
-    module_node.replace_input_with(node_64bit, node_32bit)
-
-    output_positions_64bit = set()
-    outputs_list = (
-        [submodule_outputs]
-        if isinstance(submodule_outputs, torch.Tensor)
-        else submodule_outputs
-    )

-    # Determine if any outputs of the model are 64-bit type and store their indices
-    for output_position, output in enumerate(outputs_list):
-        if output.dtype == dtype_64bit:
-            output_positions_64bit.add(output_position)
-
-    # Only enter this code block if there exists a 64-bit output
-    # This implies a cast is needed, since TRT cannot output 64-bit tensors
-    if output_positions_64bit:
-        # Determine whether the outputs of the module are tuple-type or not
-        is_tuple_output = False
-        if isinstance(submodule_outputs, tuple):
-            is_tuple_output = True
-
-        if not is_tuple_output:
-            # If the output is a single tensor, insert a cast back to int64
-            with gm.graph.inserting_after(module_node):
-                cast_node_64bit = gm.graph.call_function(
-                    torch.ops.aten._to_copy.default,
-                    args=(module_node,),
-                    kwargs={"dtype": dtype_64bit},
-                )
-
-            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
-            module_node.replace_all_uses_with(
-                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
-            )
-
-        else:
-            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
-            get_nodes = _extract_downstream_get_nodes(
-                module_node, output_positions_64bit
-            )
-
-            # For each downstream user, append a cast node back to the 64-bit precision
-            for get_node in get_nodes:
-                with gm.graph.inserting_after(get_node):
-                    cast_node_64bit = gm.graph.call_function(
-                        torch.ops.aten._to_copy.default,
-                        args=(get_node,),
-                        kwargs={"dtype": torch.int64},
-                    )
-
-                get_node.replace_all_uses_with(
-                    cast_node_64bit,
-                    delete_user_cb=lambda user: (user != cast_node_64bit),
-                )
-
-    # Clean up graph and ensure invariants are preserved
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
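
The inserting_before/_to_copy graph-surgery pattern used by the removed function above (and presumably retained by the centralized utility) is easiest to see on a toy graph. A minimal, self-contained demonstration, assuming nothing beyond stock PyTorch FX; the Double module and the float32 target dtype are illustrative only:

import torch


class Double(torch.nn.Module):  # illustrative toy module
    def forward(self, x):
        return x * 2


gm = torch.fx.symbolic_trace(Double())

# Insert a cast to float32 before the multiply node, mimicking the
# inserting_before/_to_copy pattern from repair_long_or_double_input
for node in gm.graph.nodes:
    if node.op == "call_function":
        input_node = node.all_input_nodes[0]
        with gm.graph.inserting_before(node):
            cast = gm.graph.call_function(
                torch.ops.aten._to_copy.default,
                args=(input_node,),
                kwargs={"dtype": torch.float32},
            )
        # Reroute the multiply to consume the downcast value
        node.replace_input_with(input_node, cast)
        break

gm.graph.lint()
gm.recompile()

out = gm(torch.ones(4, dtype=torch.float64))
print(out.dtype)  # torch.float32: the downcast now happens inside the graph
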
py/torch_tensorrt/dynamo/common/__init__.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
 from ._settings import CompilationSettings
-
 from .fx2trt import TRTInterpreter, TRTInterpreterResult
 from .input_tensor_spec import InputTensorSpec
+from .truncate_long_and_double import repair_long_or_double_inputs
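
End to end, the new tests exercise the flag through the Dynamo compile front-end. A condensed usage sketch along the same lines (the AddMean module is illustrative, and a CUDA device plus a TensorRT install are assumed):

import torch
from torch_tensorrt.dynamo import compile


class AddMean(torch.nn.Module):  # illustrative, mirrors the new tests
    def forward(self, x, y):
        return torch.ops.aten.mean.dim(torch.ops.aten.add.Tensor(x, y), [0])


model = torch.fx.symbolic_trace(AddMean())
inputs = [
    torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
    torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
]

# With truncate_long_and_double=True, float64/int64 inputs are downcast to
# their 32-bit counterparts at the TRT subgraph boundary and upcast back
# where subsequent Torch operations require 64-bit values
optimized = compile(
    model,
    inputs,
    min_block_size=1,
    truncate_long_and_double=True,
)
print(optimized(*inputs))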
