1
1
from __future__ import annotations
2
2
3
3
import logging
4
- from typing import List , Optional , Sequence , Tuple
4
+ from typing import Any , List , Optional , Sequence , Tuple
5
5
6
6
import torch
7
7
import torch_tensorrt
8
8
from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
9
+ from torch .utils ._pytree import tree_flatten , tree_map , tree_unflatten
9
10
from torch_tensorrt .dynamo import partitioning
10
11
11
12
logger = logging .getLogger (__name__ )
12
13
13
14
15
def _unflatten_inputs(
    flattened_inputs: Sequence[torch_tensorrt.Input],
    compiled_module: torch.fx.GraphModule,
) -> Tuple[Any, Any]:
    """Reconstruct the original ``(args, kwargs)`` structure from flattened inputs.

    The caller's inputs were flattened with ``tree_flatten`` at export time; this
    rebuilds the original nesting using the input spec stored on the compiled
    module, converting each ``torch_tensorrt.Input`` leaf into a CUDA tensor
    along the way.

    Args:
        flattened_inputs: Flattened ``torch_tensorrt.Input`` specs to process.
        compiled_module: The compiled GraphModule whose ``_in_spec`` records the
            pytree structure captured during export.

    Returns:
        Tuple of ``(args, kwargs)`` containing the reconstructed CUDA tensors.

    Raises:
        RuntimeError: If any leaf of the structure is not a
            ``torch_tensorrt.Input``.
    """

    def _to_cuda_tensor(spec: Any) -> torch.Tensor:
        # Each leaf must be an Input spec carrying a representative torch tensor;
        # anything else means the flattened inputs do not match the export spec.
        if isinstance(spec, torch_tensorrt.Input):
            return spec.torch_tensor.cuda()
        raise RuntimeError("Input is not a torch_tensorrt.Input")

    # Rebuild the (args, kwargs) pair that was flattened during export.
    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
    # Convert every leaf spec to a CUDA tensor within the rebuilt structure.
    processed_inputs = tree_map(_to_cuda_tensor, pytree_inputs)

    # tree_unflatten restored the (args, kwargs) tuple, so unpack it directly.
    return processed_inputs[0], processed_inputs[1]
45
+
14
46
class CudaGraphsTorchTensorRTModule (torch .nn .Module ): # type: ignore[misc]
15
47
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
16
48
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
42
74
Warm up is necessary to ensure that memory allocations and initializations
43
75
are not recorded in cuda graphs
44
76
"""
77
+
45
78
with torch_tensorrt .logging .errors ():
46
79
with unset_fake_temporarily ():
47
- inputs_tensor = [ spec . torch_tensor . cuda () for spec in self .inputs ]
80
+ args , kwargs = _unflatten_inputs ( self .inputs , self . compiled_module )
48
81
s = torch .cuda .Stream ()
49
82
s .wait_stream (torch .cuda .current_stream ())
50
83
with torch .cuda .stream (s ):
51
84
for _ in range (3 ):
52
- self .compiled_module (* inputs_tensor )
85
+ self .compiled_module (* args , ** kwargs )
53
86
torch .cuda .current_stream ().wait_stream (s )
54
87
55
88
def validate_input_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
@@ -73,15 +106,18 @@ def __del__(self) -> None:
73
106
if self .cudagraph :
74
107
self .cudagraph .reset ()
75
108
76
- def forward (self , * inputs : torch .Tensor ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
109
+ def forward (
110
+ self , * args : Any , ** kwargs : Any
111
+ ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
112
+ inputs , _ = tree_flatten ((args , kwargs ))
77
113
cudagraphs_enabled = torch_tensorrt .runtime .get_whole_cudagraphs_mode ()
78
114
if cudagraphs_enabled :
79
115
shape_changed = self .validate_input_shapes (inputs )
80
116
need_cudagraphs_record = shape_changed or self .is_weight_streaming_set
81
117
if need_cudagraphs_record :
82
118
if self .cudagraph :
83
119
self .cudagraph .reset ()
84
- self ._input_buffers = [None ] * len (self . inputs )
120
+ self ._input_buffers = [None ] * len (inputs )
85
121
86
122
self .is_weight_streaming_set = False
87
123
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -94,10 +130,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
94
130
for i in inputs
95
131
]
96
132
assert len (contiguous_inputs ) == len (
97
- self . inputs
98
- ), f"Wrong number of inputs, expect { len (self . inputs )} get { len (contiguous_inputs )} ."
133
+ inputs
134
+ ), f"Wrong number of inputs, expect { len (inputs )} get { len (contiguous_inputs )} ."
99
135
100
- for i , _ in enumerate (self . inputs ):
136
+ for i , _ in enumerate (inputs ):
101
137
if not contiguous_inputs [i ].is_cuda :
102
138
logger .warning (
103
139
f"Detected input[{ i } ] is not on a cuda device. "
@@ -112,8 +148,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
112
148
)
113
149
114
150
assert (
115
- contiguous_inputs [i ].dtype == self . inputs [i ].dtype
116
- ), f"Dtype mismatch for { i } th input. Expect { self . inputs [i ].dtype } , got { contiguous_inputs [i ].dtype } ."
151
+ contiguous_inputs [i ].dtype == inputs [i ].dtype
152
+ ), f"Dtype mismatch for { i } th input. Expect { inputs [i ].dtype } , got { contiguous_inputs [i ].dtype } ."
117
153
118
154
if need_cudagraphs_record :
119
155
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -122,6 +158,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
122
158
else :
123
159
self ._input_buffers [i ].copy_ (contiguous_inputs [i ])
124
160
161
+ if need_cudagraphs_record :
162
+ # Reconstruct the original args and kwargs structure from static input buffers
163
+ # using the input specification stored during module compilation
164
+ args , kwargs = tree_unflatten (
165
+ self ._input_buffers , self .compiled_module ._in_spec
166
+ )
167
+
125
168
self ._caller_stream = torch .cuda .current_stream ()
126
169
if (
127
170
self ._engine_stream == torch .cuda .default_stream ()
@@ -135,9 +178,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
135
178
if need_cudagraphs_record :
136
179
self .cudagraph = torch .cuda .CUDAGraph ()
137
180
with torch .cuda .graph (self .cudagraph , stream = self ._engine_stream ):
138
- self ._output_buffers = self .compiled_module (
139
- * self ._input_buffers
140
- )
181
+ self ._output_buffers = self .compiled_module (* args , ** kwargs )
141
182
142
183
self .cudagraph .replay () # type: ignore
143
184
self ._caller_stream .wait_stream (self ._engine_stream )
@@ -154,4 +195,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
154
195
if self .cudagraph :
155
196
self .cudagraph .reset ()
156
197
self .cudagraph = None
157
- return self .compiled_module (* inputs )
198
+ return self .compiled_module (* args , ** kwargs )
0 commit comments