Commit 0d212af

fix: structured inputs for CudaGraphsTorchTensorRTModule
1 parent 7dbd4cb commit 0d212af

File tree

1 file changed (+55 −14 lines)

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 55 additions & 14 deletions
@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor conversion to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -43,14 +75,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -77,15 +110,18 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)
 
             self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -98,10 +134,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -116,8 +152,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
                 if need_cudagraphs_record:
                     # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -126,6 +162,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
 
+            if need_cudagraphs_record:
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
+
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -139,9 +182,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if need_cudagraphs_record:
                     self.cudagraph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                        self._output_buffers = self.compiled_module(
-                            *self._input_buffers
-                        )
+                        self._output_buffers = self.compiled_module(*args, **kwargs)
 
                 self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -158,4 +199,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)
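Note: the sketch below is not part of this commit; it is a minimal illustration of the pytree round-trip the change relies on. forward() now flattens the structured (args, kwargs) it receives, and tree_unflatten with the stored input spec (compiled_module._in_spec) rebuilds that same structure, both in _unflatten_inputs and around the static input buffers before graph capture. The tensor shapes and the "mask" keyword name are arbitrary examples.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Structured inputs as a caller would pass them to forward()
args = (torch.randn(2, 3),)
kwargs = {"mask": torch.ones(2, 3)}

# Flatten into a list of leaf tensors plus a spec describing the nesting,
# mirroring inputs, _ = tree_flatten((args, kwargs)) in forward()
leaves, in_spec = tree_flatten((args, kwargs))

# Unflattening the leaves (or replacement buffers in the same layout) with
# that spec restores the exact (args, kwargs) structure
rebuilt_args, rebuilt_kwargs = tree_unflatten(leaves, in_spec)
assert torch.equal(rebuilt_args[0], args[0])
assert torch.equal(rebuilt_kwargs["mask"], kwargs["mask"])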
