Skip to content

Commit f74ad40

Browse files
committed
chore: tri-state of cuda graphs mode
1 parent 970e9bf commit f74ad40

File tree

12 files changed

+158
-247
lines changed

12 files changed

+158
-247
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,6 @@ TRTEngine::TRTEngine(
213213
LOG_DEBUG(*this);
214214
}
215215

216-
void TRTEngine::set_whole_cudagraphs(bool enable) {
217-
whole_cudagraphs = enable;
218-
}
219-
220216
TRTEngine::~TRTEngine() {
221217
trt_engine_profiler.reset();
222218
exec_ctx.reset();

core/runtime/TRTEngine.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ struct TRTEngine : torch::CustomClassHolder {
8787
bool set_device_memory_budget(int64_t budget);
8888
int64_t get_streamable_device_memory_budget();
8989
int64_t get_automatic_device_memory_budget();
90-
void set_whole_cudagraphs(bool enable);
9190
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
9291
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
9392
static const char BINDING_DELIM = '%';
@@ -104,13 +103,12 @@ struct TRTEngine : torch::CustomClassHolder {
104103
std::vector<at::Tensor> output_buffers = {};
105104
std::string shape_key;
106105
bool prev_cudagraphs_enabled = false;
107-
bool whole_cudagraphs = false;
108106
// TODO: Implement a call method
109107
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
110108

111109
void set_profiling_paths();
112110
#ifndef NDEBUG
113-
bool profile_execution = true;
111+
bool profile_execution = false;
114112
#else
115113
bool profile_execution = false;
116114
#endif

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
113113
LOG_INFO("" << log_info);
114114
compiled_engine->cudagraph.enable_debug_mode();
115115
}
116-
bool cudagraphs_enabled = (!compiled_engine->whole_cudagraphs && CUDAGRAPHS_MODE);
116+
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
117117

118118
// Whether cudagraphs needs to record the graph on this pass
119119
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change

core/runtime/register_jit_hooks.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8787
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
8888
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
8989
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
90-
.def("set_whole_cudagraphs", &TRTEngine::set_whole_cudagraphs)
9190
.def("infer_outputs", &TRTEngine::infer_outputs)
9291
.def_property(
9392
"device_memory_budget",
@@ -112,8 +111,10 @@ TORCH_LIBRARY(tensorrt, m) {
112111
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
113112
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
114113
});
115-
m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
116-
m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
114+
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
115+
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
116+
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
117+
});
117118
m.def("set_logging_level", [](int64_t level) -> void {
118119
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
119120
});

core/runtime/runtime.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace core {
88
namespace runtime {
99

1010
bool MULTI_DEVICE_SAFE_MODE = false;
11-
bool CUDAGRAPHS_MODE = false;
11+
CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;
1212

1313
c10::optional<RTDevice> get_most_compatible_device(
1414
const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
130130
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
131131
}
132132

133-
bool get_cudagraphs_mode() {
133+
CudaGraphsMode get_cudagraphs_mode() {
134134
return CUDAGRAPHS_MODE;
135135
}
136136

137-
void set_cudagraphs_mode(bool cudagraphs_mode) {
137+
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
138138
CUDAGRAPHS_MODE = cudagraphs_mode;
139139
}
140140

core/runtime/runtime.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ namespace runtime {
1818
using EngineID = int64_t;
1919
const std::string ABI_VERSION = "6";
2020
extern bool MULTI_DEVICE_SAFE_MODE;
21-
extern bool CUDAGRAPHS_MODE;
21+
22+
typedef enum {
23+
STANDARD = 0,
24+
SUBGRAPH_CUDAGRAPHS,
25+
WHOLE_GRAPH_CUDAGRAPHS,
26+
} CudaGraphsMode;
27+
28+
extern CudaGraphsMode CUDAGRAPHS_MODE;
2229

2330
typedef enum {
2431
ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();
5158

5259
void set_multi_device_safe_mode(bool multi_device_safe_mode);
5360

54-
bool get_cudagraphs_mode();
61+
CudaGraphsMode get_cudagraphs_mode();
5562

56-
void set_cudagraphs_mode(bool cudagraphs_mode);
63+
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);
5764

5865
class DeviceList {
5966
using DeviceMap = std::unordered_map<int, RTDevice>;

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ def __init__(
107107
self.engine = None
108108
self.weight_name_map = weight_name_map
109109
self.target_platform = Platform.current_platform()
110-
# Check if CUDA graph capture is enabled in the parent node
111-
self.whole_cudagraphs = False
112110
# Previous cuda graphs state
113111
self.prev_cudagraphs_enabled = False
114112

@@ -151,14 +149,6 @@ def set_default_device_memory_budget(self) -> int:
151149
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
152150
return self._set_device_memory_budget(budget_bytes)
153151

154-
def set_whole_cudagraphs(self, enable: bool) -> None:
155-
"""
156-
When the global CUDA graphs mode is enabled, the parent wrapper module handles all
157-
CUDA graph recording and replay. Therefore, any child modules must disable their
158-
own CUDA graph functionality to avoid conflicts.
159-
"""
160-
self.whole_cudagraphs = enable
161-
162152
def setup_engine(self) -> None:
163153
assert (
164154
self.target_platform == Platform.current_platform()
@@ -257,10 +247,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
257247
):
258248
self._check_initialized()
259249

260-
cudagraphs_enabled = (
261-
torch_tensorrt.runtime.get_cudagraphs_mode()
262-
and not self.whole_cudagraphs
263-
)
250+
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
251+
264252
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
265253
need_cudagraphs_record = cudagraphs_enabled and (
266254
(not self.prev_cudagraphs_enabled)

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,6 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
195195

196196
return budget_bytes
197197

198-
def set_whole_cudagraphs(self, enable: bool) -> None:
199-
"""
200-
When the global CUDA graphs mode is enabled, the parent wrapper module handles all
201-
CUDA graph recording and replay. Therefore, any child modules must disable their
202-
own CUDA graph functionality to avoid conflicts.
203-
"""
204-
self.engine.set_whole_cudagraphs(enable)
205-
206198
def setup_engine(self) -> None:
207199
"""
208200
Setup engine for a module which has deferred engine setup.

0 commit comments

Comments
 (0)