Skip to content

Commit ecee5a6

Browse files
committed
chore: update for TorchTensorRTModule
1 parent 6af0886 commit ecee5a6

File tree

7 files changed

+28
-9
lines changed

7 files changed

+28
-9
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ TRTEngine::TRTEngine(
212212
LOG_DEBUG(*this);
213213
}
214214

215+
void TRTEngine::set_cudagraphs_enabled_parent_module(bool enable) {
216+
cudagraphs_enabled_parent_module = enable;
217+
}
218+
215219
TRTEngine::~TRTEngine() {
216220
trt_engine_profiler.reset();
217221
exec_ctx.reset();

core/runtime/TRTEngine.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ struct TRTEngine : torch::CustomClassHolder {
7575
bool set_device_memory_budget(int64_t budget);
7676
int64_t get_streamable_device_memory_budget();
7777
int64_t get_automatic_device_memory_budget();
78+
void set_cudagraphs_enabled_parent_module(bool enable);
7879
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
7980
static const char BINDING_DELIM = '%';
8081

@@ -85,7 +86,8 @@ struct TRTEngine : torch::CustomClassHolder {
8586
std::vector<at::Tensor> input_buffers = {};
8687
std::vector<at::Tensor> output_buffers = {};
8788
std::string shape_key;
88-
89+
bool cudagraphs_enabled = false;
90+
bool cudagraphs_enabled_parent_module = false;
8991
// TODO: Implement a call method
9092
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
9193

core/runtime/execute_engine.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,16 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
113113
LOG_INFO("" << log_info);
114114
compiled_engine->cudagraph.enable_debug_mode();
115115
}
116+
bool cudagraphs_enabled = (!compiled_engine->cudagraphs_enabled_parent_module && CUDAGRAPHS_MODE);
116117

117118
// Whether cudagraphs needs to record the graph on this pass
118-
bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
119+
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
120+
bool need_cudagraphs_record =
121+
(((!compiled_engine->cudagraphs_enabled) && cudagraphs_enabled) ||
122+
(cudagraphs_enabled && (!_cudagraphs_validate_shapes(inputs, compiled_engine))));
123+
compiled_engine->cudagraphs_enabled = cudagraphs_enabled;
119124

120-
if (!CUDAGRAPHS_MODE) {
125+
if (!cudagraphs_enabled) {
121126
compiled_engine->cudagraph.reset();
122127
}
123128

@@ -211,7 +216,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
211216
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
212217
"Error while setting the tensor address for shape inputs");
213218

214-
if (CUDAGRAPHS_MODE) {
219+
if (cudagraphs_enabled) {
215220
// @peri044 I don't know if this makes sense since they are supposed to be GPU buffers
216221
compiled_engine->input_buffers[i] = input_cpu;
217222
}
@@ -231,7 +236,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
231236
TORCHTRT_CHECK(
232237
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
233238

234-
if (CUDAGRAPHS_MODE) {
239+
if (cudagraphs_enabled) {
235240
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
236241
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
237242
TORCHTRT_CHECK(
@@ -281,7 +286,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
281286
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
282287
}
283288

284-
if (CUDAGRAPHS_MODE) {
289+
if (cudagraphs_enabled) {
285290
TORCHTRT_CHECK(
286291
compiled_engine->exec_ctx->setTensorAddress(
287292
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -324,7 +329,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
324329
caller_exec_complete.record(compiled_engine->caller_stream);
325330
caller_exec_complete.block(compiled_engine->engine_stream);
326331

327-
if (!CUDAGRAPHS_MODE) {
332+
if (!cudagraphs_enabled) {
328333
// Direct execution uses the caller buffers directly
329334
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
330335
} else {
@@ -350,7 +355,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
350355
trt_exec_complete.record(compiled_engine->engine_stream);
351356
trt_exec_complete.block(compiled_engine->caller_stream);
352357

353-
if (CUDAGRAPHS_MODE) {
358+
if (cudagraphs_enabled) {
354359
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
355360
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
356361
outputs[o].copy_(compiled_engine->output_buffers[o], false);

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8686
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
8787
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
8888
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
89+
.def("set_cudagraphs_enabled_parent_module", &TRTEngine::set_cudagraphs_enabled_parent_module)
8990
.def_property(
9091
"device_memory_budget",
9192
&TRTEngine::get_device_memory_budget,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def set_default_device_memory_budget(self) -> int:
149149
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
150150
return self._set_device_memory_budget(budget_bytes)
151151

152+
def set_cudagraphs_enabled_parent_module(self, enable: bool) -> None:
153+
self.cudagraphs_enabled_parent_module = enable
154+
152155
def setup_engine(self) -> None:
153156
assert (
154157
self.target_platform == Platform.current_platform()

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(
131131
self.weight_name_map = weight_name_map
132132
self.serialized_engine = serialized_engine
133133
self.engine = None
134+
self.cudagraphs_enabled_parent_module = False
134135

135136
if serialized_engine and not self.settings.lazy_engine_init:
136137
self.setup_engine()
@@ -191,6 +192,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
191192

192193
return budget_bytes
193194

195+
def set_cudagraphs_enabled_parent_module(self, enable: bool) -> None:
196+
self.engine.set_cudagraphs_enabled_parent_module(enable)
197+
194198
def setup_engine(self) -> None:
195199
"""
196200
Setup engine for a module which has deferred engine setup.

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
# Disable cudagraphs in submodules as it will be enabled in wrapper
4242
for name, rt_mod in self.original_module.named_children():
4343
if "_run_on_acc" in name:
44-
rt_mod.cudagraphs_enabled_parent_module = True
44+
rt_mod.set_cudagraphs_enabled_parent_module(True)
4545

4646
# TODO: check if only torch needs warm up.
4747
with unset_fake_temporarily():

0 commit comments

Comments
 (0)