
Commit edf63bb

cherry pick #3191 from main to release/2.5 (#3237)
1 parent e314ad6 commit edf63bb

4 files changed: +4 −11 lines

core/runtime/TRTEngine.cpp

Lines changed: 0 additions & 2 deletions
@@ -73,8 +73,6 @@ TRTEngine::TRTEngine(
       << get_current_platform() << ")");
   this->target_platform = target_platform;
 
-  this->cudagraph_mempool_id = at::cuda::graph_pool_handle();
-
   this->hardware_compatible = hardware_compatible;
   auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
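
The removed constructor line had reserved a dedicated CUDA graph memory pool per engine via at::cuda::graph_pool_handle(). For orientation, a minimal Python-side sketch of the same PyTorch primitives, assuming a CUDA device is available; the tensors, shapes, and the second shared-pool graph are illustrative only and are not taken from this commit:

import torch

# Requires a CUDA device. Real code warms up the captured workload on a side
# stream before recording; that warmup is omitted to keep the sketch short.
static_input = torch.randn(8, 16, device="cuda")

# Behavior after this commit: capture_begin() with no arguments, so the graph
# allocates from its own private memory pool.
graph = torch.cuda.CUDAGraph()
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    graph.capture_begin()
    static_output = static_input * 2  # stand-in for the captured engine work
    graph.capture_end()
torch.cuda.current_stream().wait_stream(side_stream)

static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()  # re-runs the captured kernels on the same static buffers

# Role played by the removed cudagraph_mempool_id: an explicit pool handle
# that several graphs could share.
shared_pool = torch.cuda.graph_pool_handle()
other_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(other_graph, pool=shared_pool):
    other_output = static_input + 1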

core/runtime/TRTEngine.h

Lines changed: 0 additions & 1 deletion
@@ -81,7 +81,6 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key;
-  at::cuda::MempoolId_t cudagraph_mempool_id;
 
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 3 deletions
@@ -305,8 +305,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
     // Create a new stream if the engine stream is the default stream
     compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
-  } else {
-    compiled_engine->engine_stream = compiled_engine->caller_stream;
   }
 
   // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
@@ -333,7 +331,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   if (need_cudagraphs_record) {
     // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
     c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
-    compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
+    compiled_engine->cudagraph.capture_begin();
     compiled_engine->exec_ctx->enqueueV3(recording_stream);
     compiled_engine->cudagraph.capture_end();
 
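
After this change the runtime only swaps in a pooled stream when the engine stream would otherwise be the default stream (CUDA graph capture is not allowed on the default stream); the old else branch that aliased the engine stream to the caller stream is gone. A rough, hypothetical Python analogue of the resulting control flow, assuming the synchronization pattern visible in the surrounding context; _CompiledEngineStub, run_engine, and fn are placeholders, not Torch-TensorRT API:

import torch

class _CompiledEngineStub:
    """Placeholder standing in for the C++ TRTEngine; not Torch-TensorRT API."""
    def __init__(self, fn):
        self.fn = fn  # stand-in for exec_ctx->enqueueV3
        self.engine_stream = torch.cuda.default_stream()
        self.caller_stream = torch.cuda.current_stream()
        self.need_cudagraphs_record = True
        self.cudagraph = None

def run_engine(compiled, *inputs):
    compiled.caller_stream = torch.cuda.current_stream()

    # Only replace the engine stream when it is the default stream; the old
    # "else: use the caller stream" branch no longer exists.
    if compiled.engine_stream == torch.cuda.default_stream():
        compiled.engine_stream = torch.cuda.Stream()

    # Engine work waits for the caller's stream before it starts ...
    compiled.engine_stream.wait_stream(compiled.caller_stream)
    with torch.cuda.stream(compiled.engine_stream):
        if compiled.need_cudagraphs_record:
            # Capture the engine execution on the side stream, using the
            # default per-graph memory pool as in this commit.
            compiled.cudagraph = torch.cuda.CUDAGraph()
            compiled.cudagraph.capture_begin()
            outputs = compiled.fn(*inputs)
            compiled.cudagraph.capture_end()
            compiled.need_cudagraphs_record = False
        else:
            outputs = compiled.fn(*inputs)

    # ... and the caller's stream waits for the engine work to finish.
    compiled.caller_stream.wait_stream(compiled.engine_stream)
    return outputs

# Usage (illustrative): engine = _CompiledEngineStub(lambda x: x * 2)
#                       y = run_engine(engine, torch.ones(4, device="cuda"))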

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 5 deletions
@@ -5,6 +5,7 @@
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch.nn import Module
@@ -19,8 +20,6 @@
     multi_gpu_device_check,
 )
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -372,8 +371,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 or self._engine_stream is None
             ):
                 self._engine_stream = torch.cuda.Stream()
-            else:
-                self._engine_stream = self._caller_stream
 
             self._engine_stream.wait_stream(self._caller_stream)
 
@@ -464,7 +461,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         if new_shape_key != self.shape_key:
             logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
             self.shape_key = new_shape_key
-            self.cudagraph.reset()  # type: ignore
+            if self.cudagraph:
+                self.cudagraph.reset()
             return False
 
         return True
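
The Python runtime mirrors the C++ changes: tensorrt is imported alongside the other top-level imports, the engine stream no longer falls back to the caller stream, and the cudagraph is reset only if one has actually been recorded. A reduced sketch of that guarded shape-key check, assuming shape keys built from input shapes; _ShapeGuard and validate_shapes are illustrative names, not the module's real structure:

from typing import Optional, Sequence

import torch

class _ShapeGuard:
    """Illustrative stand-in for the shape-key bookkeeping in the runtime module."""

    def __init__(self) -> None:
        self.shape_key: Optional[str] = None
        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None

    def validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
        # Build a key from the current input shapes.
        new_shape_key = "".join(str(tuple(t.shape)) for t in inputs)

        if new_shape_key != self.shape_key:
            self.shape_key = new_shape_key
            # Guarded reset: on the first call, or before any graph has been
            # recorded, self.cudagraph may still be None, so only reset an
            # existing graph.
            if self.cudagraph:
                self.cudagraph.reset()
            return False  # shapes changed; the caller must re-record the graph

        return True  # shapes unchanged; the recorded graph can be replayed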
