Skip to content

Commit 7585fc8

Browse files
committed
chore: added missing test
1 parent ce063a2 commit 7585fc8

File tree

11 files changed

+251
-35
lines changed

11 files changed

+251
-35
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ TRTEngine::TRTEngine(
213213
LOG_DEBUG(*this);
214214
}
215215

216-
void TRTEngine::set_cudagraphs_enabled_parent_module(bool enable) {
217-
cudagraphs_enabled_parent_module = enable;
216+
void TRTEngine::set_whole_cudagraphs(bool enable) {
217+
whole_cudagraphs = enable;
218218
}
219219

220220
TRTEngine::~TRTEngine() {

core/runtime/TRTEngine.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ struct TRTEngine : torch::CustomClassHolder {
8787
bool set_device_memory_budget(int64_t budget);
8888
int64_t get_streamable_device_memory_budget();
8989
int64_t get_automatic_device_memory_budget();
90-
void set_cudagraphs_enabled_parent_module(bool enable);
90+
void set_whole_cudagraphs(bool enable);
9191
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
9292
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
9393
static const char BINDING_DELIM = '%';
@@ -103,8 +103,8 @@ struct TRTEngine : torch::CustomClassHolder {
103103
std::vector<at::Tensor> input_buffers = {};
104104
std::vector<at::Tensor> output_buffers = {};
105105
std::string shape_key;
106-
bool cudagraphs_enabled = false;
107-
bool cudagraphs_enabled_parent_module = false;
106+
bool prev_cudagraphs_enabled = false;
107+
bool whole_cudagraphs = false;
108108
// TODO: Implement a call method
109109
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
110110

core/runtime/execute_engine.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
113113
LOG_INFO("" << log_info);
114114
compiled_engine->cudagraph.enable_debug_mode();
115115
}
116-
bool cudagraphs_enabled = (!compiled_engine->cudagraphs_enabled_parent_module && CUDAGRAPHS_MODE);
116+
bool cudagraphs_enabled = (!compiled_engine->whole_cudagraphs && CUDAGRAPHS_MODE);
117117

118118
// Whether cudagraphs needs to record the graph on this pass
119119
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
120-
bool need_cudagraphs_record =
121-
(((!compiled_engine->cudagraphs_enabled) && cudagraphs_enabled) ||
122-
(cudagraphs_enabled && (!_cudagraphs_validate_shapes(inputs, compiled_engine))));
123-
compiled_engine->cudagraphs_enabled = cudagraphs_enabled;
120+
bool need_cudagraphs_record = cudagraphs_enabled &&
121+
((!compiled_engine->prev_cudagraphs_enabled) || (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
122+
123+
compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled;
124124

125125
if (!cudagraphs_enabled) {
126126
compiled_engine->cudagraph.reset();

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8787
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
8888
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
8989
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
90-
.def("set_cudagraphs_enabled_parent_module", &TRTEngine::set_cudagraphs_enabled_parent_module)
90+
.def("set_whole_cudagraphs", &TRTEngine::set_whole_cudagraphs)
9191
.def("infer_outputs", &TRTEngine::infer_outputs)
9292
.def_property(
9393
"device_memory_budget",

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def save(
598598
This flag is experimental for now.
599599
"""
600600
if isinstance(module, WrapperTorchTensorRTModule):
601-
module = module.original_module
601+
module = module.compiled_module
602602
module_type = _parse_module_type(module)
603603
accepted_formats = {"exported_program", "torchscript"}
604604
if arg_inputs is not None and not all(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
838838
if len(dryrun_tracker.to_run_in_torch) > 0:
839839
# Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
840840
partitioned_module = WrapperTorchTensorRTModule(
841+
gm,
841842
partitioned_module,
842843
dryrun_tracker.output_shapes,
843844
dryrun_tracker.output_dtypes,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ def __init__(
108108
self.weight_name_map = weight_name_map
109109
self.target_platform = Platform.current_platform()
110110
# Check if CUDA graph capture is enabled in the parent node
111-
self.cudagraphs_enabled_parent_module = False
112-
self.cudagraphs_enabled = False
111+
self.whole_cudagraphs = False
112+
# Previous cuda graphs state
113+
self.prev_cudagraphs_enabled = False
113114

114115
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
115116
self.setup_engine()
@@ -150,8 +151,8 @@ def set_default_device_memory_budget(self) -> int:
150151
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
151152
return self._set_device_memory_budget(budget_bytes)
152153

153-
def set_cudagraphs_enabled_parent_module(self, enable: bool) -> None:
154-
self.cudagraphs_enabled_parent_module = enable
154+
def set_whole_cudagraphs(self, enable: bool) -> None:
155+
self.whole_cudagraphs = enable
155156

156157
def setup_engine(self) -> None:
157158
assert (
@@ -254,16 +255,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
254255

255256
cudagraphs_enabled = (
256257
torch_tensorrt.runtime.get_cudagraphs_mode()
257-
and not self.cudagraphs_enabled_parent_module
258+
and not self.whole_cudagraphs
258259
)
259260
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
260-
if not self.cudagraphs_enabled and cudagraphs_enabled:
261-
need_cudagraphs_record = True
262-
else:
263-
need_cudagraphs_record = (
264-
cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
265-
)
266-
self.cudagraphs_enabled = cudagraphs_enabled
261+
need_cudagraphs_record = cudagraphs_enabled and (
262+
(not self.prev_cudagraphs_enabled)
263+
or (not self.cudagraphs_validate_shapes(inputs))
264+
)
265+
self.prev_cudagraphs_enabled = cudagraphs_enabled
267266

268267
if need_cudagraphs_record:
269268
self._input_buffers = [None] * len(self.input_names)

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
196196

197197
return budget_bytes
198198

199-
def set_cudagraphs_enabled_parent_module(self, enable: bool) -> None:
200-
self.engine.set_cudagraphs_enabled_parent_module(enable)
199+
def set_whole_cudagraphs(self, enable: bool) -> None:
200+
self.engine.set_whole_cudagraphs(enable)
201201

202202
def setup_engine(self) -> None:
203203
"""

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
import torch_tensorrt
10+
from torch._subclasses.fake_tensor import FakeTensorMode
1011
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
1112
from torch_tensorrt.dynamo import partitioning
1213
from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
@@ -17,17 +18,28 @@
1718

1819

1920
class WrapperTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
20-
"""This Wrapper runtime module to record/replay cuda graph in sub modules"""
21+
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
22+
23+
Args:
24+
original_module: Unmodified FX GraphModule
25+
compiled_module: Compiled FX GraphModule that will be wrapped
26+
output_shapes: Shapes of output Tensors of the graph
27+
output_dtypes: Output data types of the graph
28+
Returns:
29+
Output tensor or tensor list
30+
"""
2131

2232
def __init__(
2333
self,
2434
original_module: torch.nn.Module,
35+
compiled_module: torch.nn.Module,
2536
output_shapes: List[torch.Size],
2637
output_dtypes: List[torch.dtype],
2738
):
2839
super(WrapperTorchTensorRTModule, self).__init__()
2940
self.original_module = original_module
30-
self.inputs = partitioning.construct_submodule_inputs(original_module)
41+
self.compiled_module = compiled_module
42+
self.inputs = partitioning.construct_submodule_inputs(compiled_module)
3143
self.output_shapes = output_shapes
3244
self.output_dtypes = output_dtypes
3345

@@ -42,9 +54,9 @@ def __init__(
4254
self.input_is_dynamic = input_is_dynamic(self.inputs)
4355

4456
# Disable cudagraphs in submodules as it will be enabled in wrapper
45-
for name, rt_mod in self.original_module.named_children():
57+
for name, rt_mod in self.compiled_module.named_children():
4658
if "_run_on_acc" in name:
47-
rt_mod.set_cudagraphs_enabled_parent_module(True)
59+
rt_mod.set_whole_cudagraphs(True)
4860

4961
# Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
5062
with unset_fake_temporarily():
@@ -53,7 +65,7 @@ def __init__(
5365
s.wait_stream(torch.cuda.current_stream())
5466
with torch.cuda.stream(s):
5567
for _ in range(3):
56-
self.original_module(*inputs_tensor)
68+
self.compiled_module(*inputs_tensor)
5769
torch.cuda.current_stream().wait_stream(s)
5870

5971
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -71,7 +83,9 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
7183
self.shape_key = new_shape_key
7284

7385
if self.input_is_dynamic:
74-
tmp_outputs = self.original_module(*inputs)
86+
with FakeTensorMode() as mode:
87+
fake_inputs = [mode.from_tensor(input) for input in inputs]
88+
tmp_outputs = self.original_module(*fake_inputs)
7589
if not isinstance(tmp_outputs, (list, tuple)):
7690
tmp_outputs = [tmp_outputs]
7791
self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
@@ -237,7 +251,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
237251
with torch.cuda.graph(
238252
self.cudagraph, stream=self._engine_stream
239253
):
240-
self._output_buffers = self.original_module(
254+
self._output_buffers = self.compiled_module(
241255
*self._input_buffers
242256
)
243257

@@ -251,7 +265,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
251265
self.cudagraph.replay() # type: ignore
252266

253267
else:
254-
outputs = self.original_module(*inputs)
268+
outputs = self.compiled_module(*inputs)
255269

256270
self._caller_stream.wait_stream(self._engine_stream)
257271

py/torch_tensorrt/runtime/_weight_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
self.current_device_budget = 0
2323

2424
if isinstance(module, WrapperTorchTensorRTModule):
25-
module = module.original_module
25+
module = module.compiled_module
2626
for name, rt_mod in module.named_children():
2727
if "_run_on_acc" in name and isinstance(
2828
rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)

0 commit comments

Comments
 (0)