Commit 092be2a

Wrapper module around TRT + pytorch subgraphs (#3270)
1 parent bd11733 commit 092be2a

File tree: 14 files changed, +376 -58 lines

core/runtime/execute_engine.cpp

Lines changed: 10 additions & 10 deletions
@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
 void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
+    bool cudagraphs_enabled,
     bool need_cudagraphs_record) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
           compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
           "Error while setting the tensor address for shape inputs");
 
-      if (CUDAGRAPHS_MODE) {
+      if (cudagraphs_enabled) {
        // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
        compiled_engine->input_buffers[i] = input_cpu;
      }
@@ -147,7 +148,7 @@ void setup_input_tensors(
       TORCHTRT_CHECK(
           compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
 
-      if (CUDAGRAPHS_MODE) {
+      if (cudagraphs_enabled) {
        // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
        compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
        TORCHTRT_CHECK(
@@ -201,17 +202,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
     LOG_INFO("" << log_info);
     compiled_engine->cudagraph.enable_debug_mode();
   }
-
+  bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
   bool shape_changed = _validate_shapes(inputs, compiled_engine);
 
   // Whether cudagraphs needs to record the graph on this pass
   auto result = compiled_engine->runtime_states.set_runtime_states(
-      CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
+      cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
 
   bool need_cudagraphs_record = std::get<0>(result);
   bool can_use_pre_allocated_outputs = std::get<1>(result);
 
-  if (!CUDAGRAPHS_MODE || shape_changed) {
+  if (!cudagraphs_enabled || shape_changed) {
     compiled_engine->cudagraph.reset();
   }
@@ -273,8 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
         std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
   }
 
-  setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);
-
+  setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
   // Check if input shapes can be inferred.
   int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
   std::vector<char const*> names(io_size);
@@ -306,7 +306,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
       compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
     }
 
-    if (CUDAGRAPHS_MODE) {
+    if (cudagraphs_enabled) {
       TORCHTRT_CHECK(
           compiled_engine->exec_ctx->setTensorAddress(
               name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -346,7 +346,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
   caller_exec_complete.record(compiled_engine->caller_stream);
   caller_exec_complete.block(compiled_engine->engine_stream);
 
-  if (!CUDAGRAPHS_MODE) {
+  if (!cudagraphs_enabled) {
     // Direct execution uses the caller buffers directly
     compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
   } else {
@@ -377,7 +377,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
   trt_exec_complete.record(compiled_engine->engine_stream);
   trt_exec_complete.block(compiled_engine->caller_stream);
 
-  if (CUDAGRAPHS_MODE) {
+  if (cudagraphs_enabled) {
    // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
    for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
      outputs[o].copy_(compiled_engine->output_buffers[o], false);
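
The recurring swap of the boolean CUDAGRAPHS_MODE checks for cudagraphs_enabled is the heart of this file's change: per-engine CUDA Graph capture is now gated on one specific value of a tri-state mode. A minimal Python sketch of that gating (the enum values mirror CudaGraphsMode from runtime.h below; the helper name is hypothetical):

STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS = 0, 1, 2

def engine_cudagraphs_enabled(mode: int) -> bool:
    # execute_engine only records/replays its own CUDA Graph in
    # SUBGRAPH_CUDAGRAPHS mode; in WHOLE_GRAPH_CUDAGRAPHS the wrapper
    # module introduced by this commit is expected to own the capture,
    # so the engine runs directly.
    return mode == SUBGRAPH_CUDAGRAPHS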

core/runtime/register_jit_hooks.cpp

Lines changed: 4 additions & 2 deletions
@@ -112,8 +112,10 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
     MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
   });
-  m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
-  m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
+  m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
+  m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
+    CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
+  });
   m.def("set_logging_level", [](int64_t level) -> void {
     util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
   });
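
Because these lambdas are registered under TORCH_LIBRARY(tensorrt, m), the new int64_t schema is what Python callers see through torch.ops. A short usage sketch, assuming the torch_tensorrt native library is loaded (importing the package normally does this):

import torch
import torch_tensorrt  # loads the "tensorrt" custom-op library

torch.ops.tensorrt.set_cudagraphs_mode(1)  # 1 == SUBGRAPH_CUDAGRAPHS
assert torch.ops.tensorrt.get_cudagraphs_mode() == 1  # now an int, not a bool
torch.ops.tensorrt.set_cudagraphs_mode(0)  # back to STANDARD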

core/runtime/runtime.cpp

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@ namespace core {
 namespace runtime {
 
 bool MULTI_DEVICE_SAFE_MODE = false;
-bool CUDAGRAPHS_MODE = false;
+CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;
 
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
   MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
 }
 
-bool get_cudagraphs_mode() {
+CudaGraphsMode get_cudagraphs_mode() {
   return CUDAGRAPHS_MODE;
 }
 
-void set_cudagraphs_mode(bool cudagraphs_mode) {
+void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
   CUDAGRAPHS_MODE = cudagraphs_mode;
 }

core/runtime/runtime.h

Lines changed: 10 additions & 3 deletions
@@ -18,7 +18,14 @@ namespace runtime {
 using EngineID = int64_t;
 const std::string ABI_VERSION = "6";
 extern bool MULTI_DEVICE_SAFE_MODE;
-extern bool CUDAGRAPHS_MODE;
+
+typedef enum {
+  STANDARD = 0,
+  SUBGRAPH_CUDAGRAPHS,
+  WHOLE_GRAPH_CUDAGRAPHS,
+} CudaGraphsMode;
+
+extern CudaGraphsMode CUDAGRAPHS_MODE;
 
 typedef enum {
   ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();
 
 void set_multi_device_safe_mode(bool multi_device_safe_mode);
 
-bool get_cudagraphs_mode();
+CudaGraphsMode get_cudagraphs_mode();
 
-void set_cudagraphs_mode(bool cudagraphs_mode);
+void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);
 
 class DeviceList {
   using DeviceMap = std::unordered_map<int, RTDevice>;
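
Note that the enum is laid out so the old boolean values keep their meaning: false (0) maps to STANDARD and true (1) to SUBGRAPH_CUDAGRAPHS, the previous cudagraphs behavior. A hypothetical Python-side mirror (not part of the package, just for readability when passing integers to the ops above):

from enum import IntEnum

class CudaGraphsMode(IntEnum):
    # Mirrors the C++ enum in core/runtime/runtime.h
    STANDARD = 0                 # CUDA Graphs disabled (old False)
    SUBGRAPH_CUDAGRAPHS = 1      # per-TRT-engine capture (old True)
    WHOLE_GRAPH_CUDAGRAPHS = 2   # wrapper module captures TRT + PyTorch subgraphs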

docsrc/user_guide/runtime.rst

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as documented
     torch_tensorrt.runtime.set_cudagraphs_mode(False)
 
     # Enables Cudagraphs Mode, then resets the mode to its prior setting
-    with torch_tensorrt.runtime.enable_cudagraphs():
+    with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
         ...
 
 In the current implementation, use of a new input shape (for instance in dynamic shape
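
The signature change means the context manager now takes the runtime module and, per the example file below, yields a runnable module; on exit the global mode is still reset to its prior setting, as the comment in the docs notes. A hedged sketch of the updated call pattern (trt_module and inputs as in the surrounding guide):

with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(inputs)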

examples/dynamo/torch_export_cudagraphs.py

Lines changed: 49 additions & 7 deletions
@@ -11,9 +11,8 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 import torch
-import torchvision.models as models
-
 import torch_tensorrt
+import torchvision.models as models
 
 # %%
 # Compilation with `torch_tensorrt.compile` Using Default Settings
@@ -47,8 +46,8 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 # We can enable the cudagraphs API with a context manager
-with torch_tensorrt.runtime.enable_cudagraphs():
-    out_trt = opt(inputs)
+with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
+    out_trt = cudagraphs_module(inputs)
 
 # Alternatively, we can set the cudagraphs mode for the session
 torch_tensorrt.runtime.set_cudagraphs_mode(True)
@@ -64,6 +63,49 @@
 inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
 inputs_3 = torch.randn((4, 3, 224, 224)).cuda()
 
-with torch_tensorrt.runtime.enable_cudagraphs():
-    out_trt_2 = opt(inputs_2)
-    out_trt_3 = opt(inputs_3)
+with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
+    out_trt_2 = cudagraphs_module(inputs_2)
+    out_trt_3 = cudagraphs_module(inputs_3)
+
+# %%
+# CUDA Graphs with modules that contain graph breaks
+# --------------------------------------------------
+#
+# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
+# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
+# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
+# kernel launch overhead and improved execution efficiency, may be diminished.
+# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
+# that can be executed efficiently, even in the presence of graph breaks.
+# If the TensorRT module has graph breaks, the CUDA Graph context manager returns a wrapped module. This module
+# captures the entire execution graph, enabling efficient replay during subsequent inferences by reducing kernel
+# launch overheads and improving performance. Note that initializing with the wrapper module involves a warm-up
+# phase in which the module is executed several times. This warm-up ensures that memory allocations and
+# initializations are not recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize
+# performance.
+
+
+class SampleModel(torch.nn.Module):
+    def forward(self, x):
+        return torch.relu((x + 2) * 0.5)
+
+
+model = SampleModel().eval().cuda()
+input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+# The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module.
+# Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner.
+opt_with_graph_break = torch_tensorrt.compile(
+    model,
+    ir="dynamo",
+    inputs=[input],
+    min_block_size=1,
+    pass_through_build_failures=True,
+    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
+)
+
+# %%
+# If the module has graph breaks, whole submodules are recorded and replayed by CUDA Graphs
+with torch_tensorrt.runtime.enable_cudagraphs(
+    opt_with_graph_break
+) as cudagraphs_module:
+    cudagraphs_module(input)

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 1 deletion
@@ -12,6 +12,9 @@
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+    CudaGraphsTorchTensorRTModule,
+)
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
     Save the model to disk in the specified output format.
 
     Arguments:
-        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
+        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
        output_format (str): Format to save the model. Options include exported_program | torchscript.
        retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
            This flag is experimental for now.
     """
+    if isinstance(module, CudaGraphsTorchTensorRTModule):
+        module = module.compiled_module
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
     if arg_inputs is not None and not all(
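
The unwrap above lets users pass the object yielded by enable_cudagraphs straight to save(). A hedged sketch, reusing opt_with_graph_break and input from the example file earlier (the output path is arbitrary):

import torch_tensorrt

with torch_tensorrt.runtime.enable_cudagraphs(opt_with_graph_break) as cudagraphs_module:
    cudagraphs_module(input)
    # save() detects the CudaGraphsTorchTensorRTModule wrapper and
    # serializes its underlying compiled_module instead.
    torch_tensorrt.save(cudagraphs_module, "trt_model.ep", inputs=[input])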
