Commit e1e3deb

core: option to use persistent output buffer

1 parent 92ab985 commit e1e3deb

File tree

2 files changed: +71, -55 lines
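
In short: PythonTorchTensorRTModule gains a persistent_output_buffer attribute (default False). When set, the module allocates its output tensors once, binds their addresses to the TensorRT execution context, and keeps returning the same tensors until the input shapes change. A minimal usage sketch follows; the attribute name and the "_run_on_acc" submodule naming come from this diff, while the compile call, model, and example_input are illustrative assumptions:

    import torch
    import torch_tensorrt

    # Assumed dynamo-path compilation producing _run_on_acc_* TensorRT submodules
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[example_input])

    # Opt each TensorRT submodule into the persistent output buffer
    for name, rt_mod in trt_gm.named_children():
        if "_run_on_acc" in name:
            rt_mod.persistent_output_buffer = True

    out = trt_gm(example_input)  # output tensors are reused on later calls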

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 70 additions & 54 deletions

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-from contextlib import nullcontext
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
@@ -107,7 +106,10 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
-        self.cudagraphs_disabled = False
+        # Check if CUDA graph capture is enabled in the parent node
+        self.cudagraphs_parent_module = False
+        self.cudagraphs_enabled = False
+        self.persistent_output_buffer = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -239,23 +241,31 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        with nvtx.annotate(f"Forward", color="red"):
+        with nvtx.annotate("Forward", color="red"):
             self._check_initialized()
-            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() and not self.cudagraphs_disabled
-
-            need_cudagraphs_record = (
-                cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
+            cudagraphs_enabled = (
+                torch_tensorrt.runtime.get_cudagraphs_mode()
+                and not self.cudagraphs_parent_module
             )
+            shape_changed = self.validate_input_shapes(inputs)
+            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+            if not self.cudagraphs_enabled and cudagraphs_enabled:
+                need_cudagraphs_record = True
+            else:
+                need_cudagraphs_record = cudagraphs_enabled and shape_changed
+            self.cudagraphs_enabled = cudagraphs_enabled
 
             if need_cudagraphs_record:
+                if self.cudagraph:
+                    self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.input_names)
                 self._output_buffers = [None] * len(self.output_names)
 
             if not cudagraphs_enabled and self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
 
-            # If in safe mode, check at each iteration for for whether a switch is required
+            # If in safe mode, check at each iteration for whether a switch is required
             if (
                 torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
             ):
@@ -287,7 +297,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                    ]
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with nvtx.annotate(f"ProcessInputs", color="red"):
+            with nvtx.annotate("ProcessInputs", color="red"):
                 assert len(contiguous_inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
@@ -349,63 +359,66 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                        This could happen if the input tensor addresses/shapes haven't been configured correctly"
                )
 
-            with nvtx.annotate(f"ProcessOutputs", color="red"):
+            with nvtx.annotate("ProcessOutputs", color="red"):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
+                if not self.persistent_output_buffer or shape_changed:
+                    # Create and keep persistent output buffer as long as its shape does not change
+                    self._output_buffers = []
+                    for o, output_name in enumerate(self.output_names):
+                        shape = tuple(self.context.get_tensor_shape(output_name))
+                        if DYNAMIC_DIM in shape:
+                            raise ValueError(
+                                "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
+                            )
 
-                for o, output_name in enumerate(self.output_names):
-                    shape = tuple(self.context.get_tensor_shape(output_name))
-
-                    if DYNAMIC_DIM in shape:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
+                        output = torch.empty(
+                            size=shape,
+                            dtype=self.output_dtypes[o].to(torch.dtype),
+                            device=torch.cuda.current_device(),
                         )
-
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.output_dtypes[o].to(torch.dtype),
-                        device=torch.cuda.current_device(),
-                    )
-
-                    outputs.append(output)
-
-                    if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
-
-                    if cudagraphs_enabled:
-                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[o].data_ptr()
-                        )
-                    else:
-                        self.context.set_tensor_address(
-                            output_name, outputs[o].data_ptr()
-                        )
-
-            with nvtx.annotate(f"TensorRTRuntime", color="red"):
+                        if self.persistent_output_buffer:
+                            self.context.set_tensor_address(
+                                output_name, output.data_ptr()
+                            )
+                            self._output_buffers.append(output)
+                        else:
+                            outputs.append(output)
+                            if need_cudagraphs_record:
+                                self._output_buffers[o] = outputs[o].clone()
+                            if cudagraphs_enabled:
+                                self.context.set_tensor_address(
+                                    output_name, self._output_buffers[o].data_ptr()
+                                )
+                            else:
+                                self.context.set_tensor_address(
+                                    output_name, outputs[o].data_ptr()
+                                )
+            with nvtx.annotate("TensorRTRuntime", color="red"):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()
 
-                with nvtx.annotate(f"wait_stream", color="green"):
+                with nvtx.annotate("wait_stream", color="green"):
                     self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            with nvtx.annotate(f"CUDAGraph", color="green"):
+                            with nvtx.annotate("CUDAGraph", color="green"):
                                 self.cudagraph = torch.cuda.CUDAGraph()
 
                                 if self.profiling_enabled:
                                     self.cudagraph.enable_debug_mode()
-                                with nvtx.annotate(f"torch.cuda.graph", color="green"):
+                                with nvtx.annotate("torch.cuda.graph", color="green"):
                                     with torch.cuda.graph(
                                         self.cudagraph, stream=self._engine_stream
                                     ):
                                         with nvtx.annotate(
-                                            f"execute_async_v3", color="green"
+                                            "execute_async_v3", color="green"
                                         ):
                                             self.context.execute_async_v3(
                                                 self._engine_stream.cuda_stream
@@ -418,17 +431,23 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                        self.cudagraph.debug_dump(
                                            f"{tempdir}/{self.name}_cudagraph.dot"
                                        )
-                        with nvtx.annotate(f"replay", color="green"):
+                        with nvtx.annotate("replay", color="green"):
                             self.cudagraph.replay()  # type: ignore
 
                     else:
                         self.context.execute_async_v3(self._engine_stream.cuda_stream)
 
                self._caller_stream.wait_stream(self._engine_stream)
 
-            if cudagraphs_enabled:
-                for idx, o in enumerate(outputs):
-                    o.copy_(self._output_buffers[idx])
+            if self.persistent_output_buffer:
+                if len(self._output_buffers) == 1:
+                    return self._output_buffers[0]
+
+                return self._output_buffers
+            else:
+                if cudagraphs_enabled:
+                    for idx, o in enumerate(outputs):
+                        o.copy_(self._output_buffers[idx])
 
            if len(outputs) == 1:
                return outputs[0]
@@ -467,10 +486,9 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function
-        versus the version currently active for the
+        Validates the input shapes of the forward function has changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
@@ -480,10 +498,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
+            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-            if self.cudagraph:
-                self.cudagraph.reset()
-            return False
+            return True
 
-        return True
+        return False
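
Two behavioral notes follow from the diff above. First, CUDA graph re-capture is now decided in forward: a graph is recorded when the mode toggles from off to on or when validate_input_shapes reports a shape change, and the shape check itself no longer resets the graph. Second, with persistent_output_buffer set, forward returns self._output_buffers directly, so successive calls hand back tensors that share storage. A short sketch of that aliasing consequence (trt_mod, x1, x2 are hypothetical stand-ins):

    trt_mod.persistent_output_buffer = True
    out1 = trt_mod(x1)
    out2 = trt_mod(x2)    # reuses the same buffer as out1 while input shapes are unchanged
    saved = out1.clone()  # clone if an earlier result must survive the next call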

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def __init__(
         # Disable cudagrphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.original_module.named_children():
             if "_run_on_acc" in name:
-                rt_mod.cudagraphs_disabled = True
+                rt_mod.cudagraphs_parent_module = True
 
         # TODO: check if only torch needs warm up.
         with unset_fake_temporarily():
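
Renaming cudagraphs_disabled to cudagraphs_parent_module makes the intent explicit: the wrapper owns CUDA graph capture, and each _run_on_acc child only captures on its own when no parent has claimed it. A sketch of the effective gating, assuming set_cudagraphs_mode exists alongside the get_cudagraphs_mode helper used in the diff:

    torch_tensorrt.runtime.set_cudagraphs_mode(True)  # assumed toggle for the global mode
    cudagraphs_enabled = (
        torch_tensorrt.runtime.get_cudagraphs_mode()  # same check as in the diff
        and not rt_mod.cudagraphs_parent_module       # False unless a wrapper owns the graph
    )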
