Skip to content

Commit 7e22f61

Browse files
committed
chore: infer output shape from compiled module
1 parent 7585fc8 commit 7e22f61

File tree

5 files changed

+47
-25
lines changed

5 files changed

+47
-25
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
838838
if len(dryrun_tracker.to_run_in_torch) > 0:
839839
# Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
840840
partitioned_module = WrapperTorchTensorRTModule(
841-
gm,
842841
partitioned_module,
843842
dryrun_tracker.output_shapes,
844843
dryrun_tracker.output_dtypes,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ def set_default_device_memory_budget(self) -> int:
152152
return self._set_device_memory_budget(budget_bytes)
153153

154154
def set_whole_cudagraphs(self, enable: bool) -> None:
155+
"""
156+
When the global CUDA graphs mode is enabled, the parent wrapper module handles all
157+
CUDA graph recording and replay. Therefore, any child modules must disable their
158+
own CUDA graph functionality to avoid conflicts.
159+
"""
155160
self.whole_cudagraphs = enable
156161

157162
def setup_engine(self) -> None:
@@ -245,7 +250,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
245250
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
246251
for i in inputs
247252
]
248-
253+
# TODO: calculate output shapes under FakeTensorMode
254+
# fake_mode = detect_fake_mode(*inputs)
249255
with (
250256
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
251257
if self.profiling_enabled

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def __init__(
131131
self.weight_name_map = weight_name_map
132132
self.serialized_engine = serialized_engine
133133
self.engine = None
134-
self.cudagraphs_enabled_parent_module = False
135134

136135
if (
137136
serialized_engine
@@ -197,6 +196,11 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
197196
return budget_bytes
198197

199198
def set_whole_cudagraphs(self, enable: bool) -> None:
199+
"""
200+
When the global CUDA graphs mode is enabled, the parent wrapper module handles all
201+
CUDA graph recording and replay. Therefore, any child modules must disable their
202+
own CUDA graph functionality to avoid conflicts.
203+
"""
200204
self.engine.set_whole_cudagraphs(enable)
201205

202206
def setup_engine(self) -> None:

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@ class WrapperTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
3131

3232
def __init__(
3333
self,
34-
original_module: torch.nn.Module,
3534
compiled_module: torch.nn.Module,
3635
output_shapes: List[torch.Size],
3736
output_dtypes: List[torch.dtype],
3837
):
3938
super(WrapperTorchTensorRTModule, self).__init__()
40-
self.original_module = original_module
4139
self.compiled_module = compiled_module
4240
self.inputs = partitioning.construct_submodule_inputs(compiled_module)
4341
self.output_shapes = output_shapes
@@ -48,7 +46,7 @@ def __init__(
4846
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
4947
self.shape_key: Optional[str] = None
5048
self.profiling_enabled = False
51-
self.cudagraphs_enabled = False
49+
self.prev_cudagraphs_enabled = False
5250
self._caller_stream: Optional[torch.cuda.Stream] = None
5351
self._engine_stream: Optional[torch.cuda.Stream] = None
5452
self.input_is_dynamic = input_is_dynamic(self.inputs)
@@ -57,20 +55,27 @@ def __init__(
5755
for name, rt_mod in self.compiled_module.named_children():
5856
if "_run_on_acc" in name:
5957
rt_mod.set_whole_cudagraphs(True)
58+
self.warm_up()
6059

61-
# Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
62-
with unset_fake_temporarily():
63-
inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
64-
s = torch.cuda.Stream()
65-
s.wait_stream(torch.cuda.current_stream())
66-
with torch.cuda.stream(s):
67-
for _ in range(3):
68-
self.compiled_module(*inputs_tensor)
69-
torch.cuda.current_stream().wait_stream(s)
60+
def warm_up(self) -> None:
61+
"""
62+
Warm up is necessary to ensure that memory allocations and initializations
63+
are not recorded in cuda graphs
64+
"""
65+
with torch_tensorrt.logging.errors():
66+
with unset_fake_temporarily():
67+
inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
68+
s = torch.cuda.Stream()
69+
s.wait_stream(torch.cuda.current_stream())
70+
with torch.cuda.stream(s):
71+
for _ in range(3):
72+
self.compiled_module(*inputs_tensor)
73+
torch.cuda.current_stream().wait_stream(s)
7074

7175
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
7276
"""
7377
Checks whether the input shapes of the forward function have changed,
78+
and infers the output shapes if a dynamic input shape has changed.
7479
"""
7580
# Representation of input shapes to a given model
7681
# Shapes are concatenated as so:
@@ -83,13 +88,12 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
8388
self.shape_key = new_shape_key
8489

8590
if self.input_is_dynamic:
86-
with FakeTensorMode() as mode:
87-
fake_inputs = [mode.from_tensor(input) for input in inputs]
88-
tmp_outputs = self.original_module(*fake_inputs)
91+
with FakeTensorMode(allow_non_fake_inputs=True):
92+
tmp_outputs = self.compiled_module(*inputs)
8993
if not isinstance(tmp_outputs, (list, tuple)):
9094
tmp_outputs = [tmp_outputs]
9195
self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
92-
96+
print("self.output_shapes ", self.output_shapes)
9397
return True
9498

9599
return False
@@ -114,11 +118,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
114118
shape_changed = self.validate_input_shapes(inputs)
115119
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
116120
# A CUDA graph re-record is required whenever cudagraphs_enabled is toggled to True, regardless of any shape change
117-
if not self.cudagraphs_enabled and cudagraphs_enabled:
118-
need_cudagraphs_record = True
119-
else:
120-
need_cudagraphs_record = cudagraphs_enabled and shape_changed
121-
self.cudagraphs_enabled = cudagraphs_enabled
121+
need_cudagraphs_record = cudagraphs_enabled and (
122+
(not self.prev_cudagraphs_enabled) or shape_changed
123+
)
124+
self.prev_cudagraphs_enabled = cudagraphs_enabled
122125

123126
if need_cudagraphs_record:
124127
if self.cudagraph:
@@ -282,4 +285,5 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
282285

283286
return outputs
284287
else:
288+
285289
return outputs

py/torch_tensorrt/dynamo/runtime/register_fake_class.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, List
44

55
import torch
6+
from torch._library.fake_class_registry import FakeScriptObject
67
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape
78

89

@@ -26,7 +27,12 @@ def fake_tensorrt_execute_engine(
2627
modes = ["opt"]
2728

2829
# Get the TRTEngine class and infer output shapes based on input shapes
29-
trt_engine = fake_trt_engine.wrapped_obj.engine
30+
# If fake_trt_engine is not a FakeScriptObject, assume that it is already the real object
31+
if isinstance(fake_trt_engine, FakeScriptObject):
32+
trt_engine = fake_trt_engine.wrapped_obj.engine
33+
else:
34+
trt_engine = fake_trt_engine
35+
3036
outputs_mode_dict = defaultdict(list)
3137
for mode in modes:
3238
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
@@ -125,5 +131,8 @@ def automatic_device_memory_budget_getter(self) -> Any:
125131
def infer_outputs(self, input_shapes: List[Any]) -> Any:
126132
pass
127133

134+
def set_whole_cudagraphs(self) -> Any:
135+
pass
136+
128137
def __setstate__(self, serialized_state: List[str]) -> Any:
129138
pass

0 commit comments

Comments
 (0)