Skip to content

Commit ba36046

Browse files
committed
chore: pre-allocated output buffer for latency hiding
1 parent 291ae89 commit ba36046

File tree

2 files changed

+56
-52
lines changed

2 files changed

+56
-52
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
835835

836836
dryrun_stats_display(dryrun_tracker, settings.dryrun)
837837

838-
if len(trt_modules) > 1:
838+
# if len(trt_modules) > 1:
839+
if True:
839840
# Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
840841
partitioned_module = WrapperTorchTensorRTModule(
841842
partitioned_module, dryrun_tracker.output_dtypes

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def __init__(
109109
# Check if CUDA graph capture is enabled in the parent node
110110
self.cudagraphs_parent_module = False
111111
self.cudagraphs_enabled = False
112-
self.persistent_output_buffer = False
112+
self.pre_allocated_outputs: List[torch.Tensor] = []
113+
self.use_pre_allocated_outputs = False
113114

114115
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
115116
self.setup_engine()
@@ -174,7 +175,7 @@ def setup_engine(self) -> None:
174175
self.engine.get_tensor_shape(input_name) for input_name in self.input_names
175176
]
176177
self.output_dtypes = [
177-
dtype._from(self.engine.get_tensor_dtype(output_name))
178+
dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
178179
for output_name in self.output_names
179180
]
180181
self.output_shapes = [
@@ -235,6 +236,19 @@ def __del__(self) -> None:
235236
if self.cudagraph:
236237
self.cudagraph.reset()
237238

239+
def create_output_tensors(self) -> List[torch.Tensor]:
240+
# create output tensors
241+
outputs: List[torch.Tensor] = []
242+
243+
for o, _ in enumerate(self.output_names):
244+
output = torch.empty(
245+
size=self.output_shapes[o],
246+
dtype=self.output_dtypes[o],
247+
device=torch.cuda.current_device(),
248+
)
249+
outputs.append(output)
250+
return outputs
251+
238252
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
239253
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
240254
contiguous_inputs: List[torch.Tensor] = [
@@ -350,50 +364,41 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
350364
self.context.set_tensor_address(
351365
input_name, contiguous_inputs[i].data_ptr()
352366
)
367+
if shape_changed:
368+
# Check if input shapes can be inferred.
369+
uninferred_input_names = self.context.infer_shapes()
370+
if uninferred_input_names:
371+
logger.warning(
372+
f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
373+
This could happen if the input tensor addresses/shapes haven't been configured correctly"
374+
)
353375

354-
# Check if input shapes can be inferred.
355-
uninferred_input_names = self.context.infer_shapes()
356-
if uninferred_input_names:
357-
logger.warning(
358-
f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
359-
This could happen if the input tensor addresses/shapes haven't been configured correctly"
360-
)
361-
362-
with nvtx.annotate("ProcessOutputs", color="red"):
363-
# create output tensors
364-
outputs: List[torch.Tensor] = []
365-
if not self.persistent_output_buffer or shape_changed:
366-
# Create and keep persistent output buffer as long as its shape does not change
367-
self._output_buffers = []
368-
for o, output_name in enumerate(self.output_names):
369-
shape = tuple(self.context.get_tensor_shape(output_name))
370-
if DYNAMIC_DIM in shape:
371-
raise ValueError(
372-
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
373-
)
376+
with nvtx.annotate("ProcessOutputs:1", color="red"):
377+
if not self.use_pre_allocated_outputs or shape_changed:
378+
self.output_shapes = [
379+
tuple(self.context.get_tensor_shape(output_name))
380+
for output_name in self.output_names
381+
]
382+
if DYNAMIC_DIM in self.output_shapes:
383+
raise ValueError(
384+
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
385+
)
386+
outputs = self.create_output_tensors()
387+
else:
388+
outputs = self.pre_allocated_outputs
374389

375-
output = torch.empty(
376-
size=shape,
377-
dtype=self.output_dtypes[o].to(torch.dtype),
378-
device=torch.cuda.current_device(),
390+
for o, output_name in enumerate(self.output_names):
391+
if need_cudagraphs_record:
392+
self._output_buffers[o] = outputs[o].clone()
393+
if cudagraphs_enabled:
394+
self.context.set_tensor_address(
395+
output_name, self._output_buffers[o].data_ptr()
379396
)
380-
if self.persistent_output_buffer:
381-
self.context.set_tensor_address(
382-
output_name, output.data_ptr()
383-
)
384-
self._output_buffers.append(output)
385-
else:
386-
outputs.append(output)
387-
if need_cudagraphs_record:
388-
self._output_buffers[o] = outputs[o].clone()
389-
if cudagraphs_enabled:
390-
self.context.set_tensor_address(
391-
output_name, self._output_buffers[o].data_ptr()
392-
)
393-
else:
394-
self.context.set_tensor_address(
395-
output_name, outputs[o].data_ptr()
396-
)
397+
else:
398+
self.context.set_tensor_address(
399+
output_name, outputs[o].data_ptr()
400+
)
401+
397402
with nvtx.annotate("TensorRTRuntime", color="red"):
398403
self._caller_stream = torch.cuda.current_stream()
399404
if (
@@ -439,15 +444,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
439444

440445
self._caller_stream.wait_stream(self._engine_stream)
441446

442-
if self.persistent_output_buffer:
443-
if len(self._output_buffers) == 1:
444-
return self._output_buffers[0]
447+
if self.use_pre_allocated_outputs:
448+
with nvtx.annotate("ProcessOutputs:2", color="red"):
449+
self.pre_allocated_outputs = self.create_output_tensors()
445450

446-
return self._output_buffers
447-
else:
448-
if cudagraphs_enabled:
449-
for idx, o in enumerate(outputs):
450-
o.copy_(self._output_buffers[idx])
451+
if cudagraphs_enabled:
452+
for idx, o in enumerate(outputs):
453+
o.copy_(self._output_buffers[idx])
451454

452455
if len(outputs) == 1:
453456
return outputs[0]

0 commit comments

Comments (0)