Commit 6af0886

chore: remove output buffer opt
1 parent 7f654e3 commit 6af0886

3 files changed: +81 -73 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 2 deletions
@@ -519,8 +519,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     dryrun_stats_display(dryrun_tracker, settings.dryrun)
 
-    # if len(trt_modules) > 1:
-    if True:
+    if len(trt_modules) > 1:
        # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
        partitioned_module = WrapperTorchTensorRTModule(
            partitioned_module, dryrun_tracker.output_dtypes
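The wrapped runtime module relies on CUDA Graph capture/replay. As a minimal standalone sketch of that pattern with torch.cuda.CUDAGraph (illustrative names, assumes a CUDA device; not the repository's code):

import torch

def capture_and_replay(fn, static_input: torch.Tensor) -> torch.Tensor:
    # Warm up on a side stream so capture starts from a settled allocator state.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Record the op sequence once; replay() re-executes it with minimal launch overhead.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    # Replays reuse the captured buffers, so new data must be copied into the
    # captured input tensor before each replay.
    static_input.copy_(torch.randn_like(static_input))
    graph.replay()
    return static_output

For example, capture_and_replay(lambda t: t * 2, torch.randn(8, device="cuda")) records and replays a single elementwise kernel.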

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 79 additions & 70 deletions
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-import nvtx
 import tensorrt as trt
 import torch
 import torch_tensorrt
@@ -107,10 +107,8 @@ def __init__(
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         # Check if CUDA graph capture is enabled in the parent node
-        self.cudagraphs_parent_module = False
+        self.cudagraphs_enabled_parent_module = False
         self.cudagraphs_enabled = False
-        self.pre_allocated_outputs: List[torch.Tensor] = []
-        self.use_pre_allocated_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -236,42 +234,34 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
-    def create_output_tensors(self) -> List[torch.Tensor]:
-        # create output tensors
-        outputs: List[torch.Tensor] = []
-
-        for o, _ in enumerate(self.output_names):
-            output = torch.empty(
-                size=self.output_shapes[o],
-                dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
-            )
-            outputs.append(output)
-        return outputs
-
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        with nvtx.annotate("Forward", color="red"):
+
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             self._check_initialized()
+
             cudagraphs_enabled = (
                 torch_tensorrt.runtime.get_cudagraphs_mode()
-                and not self.cudagraphs_parent_module
+                and not self.cudagraphs_enabled_parent_module
             )
-            shape_changed = self.validate_input_shapes(inputs)
-            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+            # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
             if not self.cudagraphs_enabled and cudagraphs_enabled:
                 need_cudagraphs_record = True
             else:
-                need_cudagraphs_record = cudagraphs_enabled and shape_changed
+                need_cudagraphs_record = (
+                    cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
+                )
             self.cudagraphs_enabled = cudagraphs_enabled
 
             if need_cudagraphs_record:
-                if self.cudagraph:
-                    self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.input_names)
                 self._output_buffers = [None] * len(self.output_names)

@@ -311,7 +301,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ]
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with nvtx.annotate("ProcessInputs", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(contiguous_inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
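The record_function/nullcontext construction replacing the nvtx ranges keeps the hot path free of profiling overhead when profiling is off. A self-contained sketch of the same conditional-context idiom (profiling_enabled and heavy_work are illustrative names, not from the module):

import torch
from contextlib import nullcontext

profiling_enabled = True

def heavy_work(x: torch.Tensor) -> torch.Tensor:
    return x @ x

x = torch.randn(64, 64)
# Enter a named profiler range only when profiling is on; otherwise a no-op context.
with (
    torch.autograd.profiler.record_function("Example:HeavyWork")
    if profiling_enabled
    else nullcontext()
):
    y = heavy_work(x)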
@@ -364,32 +360,44 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.context.set_tensor_address(
                         input_name, contiguous_inputs[i].data_ptr()
                     )
-                if shape_changed:
-                    # Check if input shapes can be inferred.
-                    uninferred_input_names = self.context.infer_shapes()
-                    if uninferred_input_names:
-                        logger.warning(
-                            f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                                This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                        )
 
-            with nvtx.annotate("ProcessOutputs:1", color="red"):
-                if not self.use_pre_allocated_outputs or shape_changed:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
-                    if DYNAMIC_DIM in self.output_shapes:
+                # Check if input shapes can be inferred.
+                uninferred_input_names = self.context.infer_shapes()
+                if uninferred_input_names:
+                    logger.warning(
+                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                            This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                    )
+
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                # create output tensors
+                outputs: List[torch.Tensor] = []
+
+                for o, output_name in enumerate(self.output_names):
+                    shape = tuple(self.context.get_tensor_shape(output_name))
+
+                    if DYNAMIC_DIM in shape:
                         raise ValueError(
                             "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                         )
-                    outputs = self.create_output_tensors()
-                else:
-                    outputs = self.pre_allocated_outputs
 
-                for o, output_name in enumerate(self.output_names):
+                    output = torch.empty(
+                        size=shape,
+                        dtype=self.output_dtypes[o],
+                        device=torch.cuda.current_device(),
+                    )
+
+                    outputs.append(output)
+
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()
+
                     if cudagraphs_enabled:
                         self.context.set_tensor_address(
                             output_name, self._output_buffers[o].data_ptr()
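With the pre-allocated output buffers removed, outputs are allocated per call from the shapes the execution context reports. A hedged sketch of that allocation step (the literal shapes and dtypes stand in for context.get_tensor_shape results; assumes a CUDA device):

from typing import List, Sequence, Tuple
import torch

def allocate_outputs(
    shapes: Sequence[Tuple[int, ...]], dtypes: Sequence[torch.dtype]
) -> List[torch.Tensor]:
    # One fresh tensor per engine output, sized from the shape reported at runtime.
    return [
        torch.empty(size=s, dtype=d, device=torch.cuda.current_device())
        for s, d in zip(shapes, dtypes)
    ]

outs = allocate_outputs([(1, 8), (1, 4)], [torch.float32, torch.float32])

The trade-off is one allocation per forward call in exchange for simpler state: there are no stale buffers to track across shape changes.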
@@ -399,35 +407,37 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                             output_name, outputs[o].data_ptr()
                         )
 
-            with nvtx.annotate("TensorRTRuntime", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()
 
-                with nvtx.annotate("wait_stream", color="green"):
-                    self._engine_stream.wait_stream(self._caller_stream)
+                self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
+
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            with nvtx.annotate("CUDAGraph", color="green"):
-                                self.cudagraph = torch.cuda.CUDAGraph()
+                            self.cudagraph = torch.cuda.CUDAGraph()
 
                             if self.profiling_enabled:
                                 self.cudagraph.enable_debug_mode()
-                            with nvtx.annotate("torch.cuda.graph", color="green"):
-                                with torch.cuda.graph(
-                                    self.cudagraph, stream=self._engine_stream
-                                ):
-                                    with nvtx.annotate(
-                                        "execute_async_v3", color="green"
-                                    ):
-                                        self.context.execute_async_v3(
-                                            self._engine_stream.cuda_stream
-                                        )
+
+                            with torch.cuda.graph(
+                                self.cudagraph, stream=self._engine_stream
+                            ):
+                                self.context.execute_async_v3(
+                                    self._engine_stream.cuda_stream
+                                )
 
                         if self.profiling_enabled:
                             import tempfile
@@ -436,18 +446,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                 self.cudagraph.debug_dump(
                                     f"{tempdir}/{self.name}_cudagraph.dot"
                                 )
-                        with nvtx.annotate("replay", color="green"):
-                            self.cudagraph.replay()  # type: ignore
+
+                        self.cudagraph.replay()  # type: ignore
 
                     else:
                         self.context.execute_async_v3(self._engine_stream.cuda_stream)
 
                 self._caller_stream.wait_stream(self._engine_stream)
 
-            if self.use_pre_allocated_outputs:
-                with nvtx.annotate("ProcessOutputs:2", color="red"):
-                    self.pre_allocated_outputs = self.create_output_tensors()
-
             if cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])
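The runtime executes the engine on a dedicated stream and re-synchronizes with the caller afterward. A minimal sketch of that wait_stream handoff (the matmul is a stand-in for context.execute_async_v3; assumes a CUDA device):

import torch

caller_stream = torch.cuda.current_stream()
engine_stream = torch.cuda.Stream()

# Engine work may start only after the caller's pending kernels finish...
engine_stream.wait_stream(caller_stream)
with torch.cuda.stream(engine_stream):
    a = torch.randn(32, 32, device="cuda")
    b = a @ a  # stand-in for context.execute_async_v3(engine_stream.cuda_stream)
# ...and the caller resumes only after the engine-stream work completes.
caller_stream.wait_stream(engine_stream)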
@@ -489,9 +495,10 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function has changed
+        Validates the input shapes of the forward function
+        versus the version currently active for the
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
@@ -501,8 +508,10 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
+            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
             self.shape_key = new_shape_key
-            return True
+            if self.cudagraph:
+                self.cudagraph.reset()
+            return False
 
-        return False
+        return True
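A standalone sketch of the shape-key check this method performs: concatenate the input shapes into a string key and invalidate the recorded CUDA graph when the key changes (illustrative class, not the module's API):

from typing import Optional, Sequence
import torch

class ShapeKeyCache:
    def __init__(self) -> None:
        self.shape_key: Optional[str] = None
        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None

    def validate(self, inputs: Sequence[torch.Tensor]) -> bool:
        # e.g. two (2, 3) inputs -> "(2, 3)(2, 3)"
        new_shape_key = "".join(str(tuple(t.shape)) for t in inputs)
        if new_shape_key != self.shape_key:
            self.shape_key = new_shape_key
            if self.cudagraph:
                self.cudagraph.reset()
            return False  # shapes changed: caller must re-record the graph
        return True       # shapes match the currently active graph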

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(
         # Disable cudagrphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.original_module.named_children():
             if "_run_on_acc" in name:
-                rt_mod.cudagraphs_parent_module = True
+                rt_mod.cudagraphs_enabled_parent_module = True
 
         # TODO: check if only torch needs warm up.
         with unset_fake_temporarily():
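The renamed flag lets the wrapper claim graph capture for the whole partitioned module: children matching "_run_on_acc" see cudagraphs_enabled_parent_module and skip their own capture. A hedged usage sketch (trt_module and inputs are placeholders; assumes the set_cudagraphs_mode toggle pairing the get_cudagraphs_mode call used in forward):

import torch_tensorrt

# trt_module = torch_tensorrt.compile(model, ...)  # placeholder compiled module
torch_tensorrt.runtime.set_cudagraphs_mode(True)   # assumed toggle; pairs with get_cudagraphs_mode
# out = trt_module(*inputs)  # submodules defer capture; the wrapper records one graph
torch_tensorrt.runtime.set_cudagraphs_mode(False)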
