Skip to content

Commit d16585f

Browse files
committed
chore: update streams
1 parent 05627cd commit d16585f

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
178178
enqueue_profiler_guard =
179179
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
180180
}
181-
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
182-
181+
c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/true, inputs[0].device().index());
183182
// nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
184183
std::unique_lock<std::mutex> lock(compiled_engine->mu);
185184
compiled_engine->exec_ctx->enqueueV3(stream);

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _initialize(self) -> None:
5959
runtime = trt.Runtime(logger)
6060
self.engine = runtime.deserialize_cuda_engine(self.engine)
6161
self.context = self.engine.create_execution_context()
62-
62+
self.stream = torch.cuda.Stream(torch.cuda.current_device())
6363
# Indices of inputs/outputs in the trt engine bindings, in the order
6464
# as they are in the original PyTorch model.
6565

@@ -286,7 +286,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
286286
if self.profiling_enabled
287287
else nullcontext()
288288
):
289-
self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
289+
self.context.execute_async_v3(self.stream.cuda_stream)
290290

291291
if len(outputs) == 1:
292292
return outputs[0]

0 commit comments

Comments
 (0)