Commit 15df89f

fix: structured inputs for CudaGraphsTorchTensorRTModule
1 parent 0a46392 commit 15df89f

File tree: 1 file changed (+60, -19 lines)

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
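
For context (not part of the diff): the new helper leans on torch.utils._pytree. A minimal sketch of the round-trip it performs, with illustrative tensors:

import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

# Flattening an (args, kwargs) pair yields a flat leaf list plus a TreeSpec;
# tree_unflatten restores the original structure from that spec.
args = (torch.arange(2.0), {"scale": torch.ones(1)})
kwargs = {"mask": torch.zeros(2, dtype=torch.bool)}

flat, spec = tree_flatten((args, kwargs))      # flat is a plain list of tensors
restored = tree_unflatten(flat, spec)          # restored == (args, kwargs)
doubled = tree_map(lambda t: t * 2, restored)  # the function is applied leaf by leaf

assert torch.equal(doubled[0][0], args[0] * 2)

The helper itself uses the TreeSpec captured at export time (compiled_module._in_spec) rather than one produced locally, so the reconstructed structure matches the original call signature.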
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,15 +106,18 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)
 
             self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -94,10 +130,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +148,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])
 
             if need_cudagraphs_record:
-                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
-            else:
-                self._input_buffers[i].copy_(contiguous_inputs[i])
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
 
             self._caller_stream = torch.cuda.current_stream()
             if (
@@ -135,9 +178,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                    self._output_buffers = self.compiled_module(
-                        *self._input_buffers
-                    )
+                    self._output_buffers = self.compiled_module(*args, **kwargs)
 
             self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +195,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)
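
Hypothetical end-to-end usage after this change (the enable_cudagraphs context manager and the call shapes are assumptions, not taken from this commit): the wrapper can now be invoked with the same structured arguments the original module was exported with, keyword arguments included.

import torch_tensorrt

# `trt_module`, `x`, and `mask` are placeholders for a compiled module and its inputs
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cg_module:
    out = cg_module(x, mask=mask)  # kwargs now survive the cudagraphs wrapper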
