23
23
logger = logging .getLogger (__name__ )
24
24
25
25
26
class DynamicOutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
    """TensorRT output allocator for data-dependent-shape (DDS) outputs.

    TensorRT invokes this object during ``execute_async_v3`` to obtain GPU
    memory for outputs whose size is only known at runtime, and to report
    each output's final shape once it is determined.
    """

    def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
        trt.IOutputAllocator.__init__(self)
        # tensor name -> currently allocated backing buffer (grown on demand)
        self.buffers: Dict[str, torch.Tensor] = {}
        # tensor name -> final shape reported via notify_shape()
        self.shapes: Dict[str, Tuple[int, ...]] = {}
        # tensor name -> expected torch dtype of that output
        self.dtypes: Dict[str, torch.dtype] = output_dtypes

    def reallocate_output_async(
        self,
        tensor_name: str,
        memory: int,
        size: int,
        alignment: int,
        stream: torch.cuda.Stream,
    ) -> Any:
        """Return a device pointer to a buffer of at least ``size`` units.

        A fresh buffer is allocated only when none exists yet for this
        tensor or the cached one has a different size; otherwise the cached
        buffer is reused.
        """
        # NOTE(review): `size` presumably is a byte count per the TensorRT
        # IOutputAllocator contract, so allocating `size` *elements* of the
        # output dtype over-allocates — confirm against the API docs.
        requested = (size,)
        cached = self.buffers.get(tensor_name)
        if cached is None or cached.shape != requested:
            self.buffers[tensor_name] = torch.empty(
                requested,
                dtype=self.dtypes[tensor_name],
                device=torch.cuda.current_device(),
            )
        return self.buffers[tensor_name].data_ptr()

    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
        """Record the actual output shape TensorRT reports after execution."""
        self.shapes[tensor_name] = tuple(shape)
59
+
60
+
26
61
class TorchTRTRuntimeStates :
27
62
def __init__ (self , new_cudagraphs : bool ):
28
63
# Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -164,8 +199,11 @@ def __init__(
164
199
self .runtime_states = TorchTRTRuntimeStates (
165
200
torch_tensorrt .runtime .get_cudagraphs_mode ()
166
201
)
202
+
203
+ self .contains_dds_layer = False
167
204
self .pre_allocated_outputs : List [torch .Tensor ] = []
168
205
self .use_pre_allocated_outputs = False
206
+ self .output_allocator : Optional [DynamicOutputAllocator ] = None
169
207
170
208
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
171
209
self .setup_engine ()
@@ -238,9 +276,19 @@ def setup_engine(self) -> None:
238
276
for output_name in self .output_names
239
277
]
240
278
279
+ self .contains_dds_layer = self ._check_dds_layer ()
280
+ if self .contains_dds_layer :
281
+ self .setup_output_allocator ()
282
+
241
283
if torch_tensorrt .runtime .get_cudagraphs_mode ():
242
284
self .cudagraph = torch .cuda .CUDAGraph ()
243
285
286
def _check_dds_layer(self) -> bool:
    """Return True if the engine contains a data-dependent-shape (DDS) layer.

    Scans the engine's layer information dump for the "trainStation"
    marker, which identifies a DDS layer in TensorRT's layer metadata.
    """
    layer_info = self.get_layer_info()
    # Idiomatic form of: if marker present return True, else return False.
    return "trainStation" in layer_info
291
+
244
292
def _check_initialized(self) -> None:
    """Raise ``RuntimeError`` unless the engine has been set up.

    Guard used by methods that require ``setup_engine`` to have run.
    """
    if self.initialized:
        return
    raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -358,19 +406,22 @@ def create_output_tensors(self) -> List[torch.Tensor]:
358
406
def set_pre_allocated_outputs(self, enable: bool) -> None:
    """Enable or disable reuse of pre-allocated output tensors in forward."""
    self.use_pre_allocated_outputs = enable
360
408
409
def setup_output_allocator(self) -> None:
    """Create (once) and register a ``DynamicOutputAllocator`` for all outputs.

    The allocator is built lazily on first call and then registered with the
    execution context for every output tensor name.

    Raises:
        RuntimeError: if TensorRT refuses the allocator for any output.
    """
    if self.output_allocator is None:
        # output_names and output_dtypes are parallel sequences; pair them
        # directly instead of indexing by position.
        output_dtypes_dict = dict(zip(self.output_names, self.output_dtypes))
        self.output_allocator = DynamicOutputAllocator(output_dtypes_dict)

    for output_name in self.output_names:
        if not self.context.set_output_allocator(output_name, self.output_allocator):
            raise RuntimeError(f"Failed to set output allocator for {output_name}")
421
+
361
422
def forward (self , * inputs : torch .Tensor ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
362
- # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
363
- contiguous_inputs : List [torch .Tensor ] = [
364
- (i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
365
- for i in inputs
366
- ]
367
- with (
368
- torch .autograd .profiler .record_function ("PythonTorchTensorRTModule:Forward" )
369
- if self .profiling_enabled
370
- else nullcontext ()
371
- ):
372
- self ._check_initialized ()
373
423
424
+ def run_cuda_graph () -> torch .Tensor | Tuple [torch .Tensor , ...]:
374
425
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
375
426
shape_changed = self .validate_input_shapes (inputs )
376
427
(
@@ -389,38 +440,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
389
440
self ._input_buffers = [None ] * len (self .input_names )
390
441
self ._output_buffers = [None ] * len (self .output_names )
391
442
392
- # If in safe mode, check at each iteration for whether a switch is required
393
- if (
394
- torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
395
- ):
396
- curr_device_id = torch .cuda .current_device ()
397
- curr_device_properties = torch .cuda .get_device_properties (
398
- curr_device_id
399
- )
400
- logger .debug (f"Current Device: cuda:{ curr_device_id } " )
401
-
402
- # If a switch is required, move all inputs to new device and set as active device
403
- if _is_switch_required (
404
- curr_device_id ,
405
- self .target_device_id ,
406
- curr_device_properties ,
407
- self .target_device_properties ,
408
- ):
409
- device_id , _ = _select_rt_device (
410
- curr_device_id ,
411
- self .target_device_id ,
412
- self .target_device_properties ,
413
- )
414
-
415
- # Update current device
416
- device = torch .device (device_id )
417
- torch .cuda .set_device (device_id )
418
-
419
- contiguous_inputs = [
420
- tensor .to (device ) for tensor in contiguous_inputs
421
- ]
422
- logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
423
-
424
443
with (
425
444
torch .autograd .profiler .record_function (
426
445
"PythonTorchTensorRTModule:ProcessInputs"
@@ -536,6 +555,118 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
536
555
537
556
return outputs
538
557
558
+ def run_output_allocator () -> torch .Tensor | Tuple [torch .Tensor , ...]:
559
+ with (
560
+ torch .autograd .profiler .record_function (
561
+ "PythonTorchTensorRTModule:ProcessInputs"
562
+ )
563
+ if self .profiling_enabled
564
+ else nullcontext ()
565
+ ):
566
+ assert len (contiguous_inputs ) == len (
567
+ self .input_names
568
+ ), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
569
+
570
+ self .setup_input_tensors (contiguous_inputs , False , False )
571
+
572
+ with (
573
+ torch .autograd .profiler .record_function (
574
+ "PythonTorchTensorRTModule:TensorRTRuntime"
575
+ )
576
+ if self .profiling_enabled
577
+ else nullcontext ()
578
+ ):
579
+ self ._caller_stream = torch .cuda .current_stream ()
580
+ if (
581
+ self ._engine_stream == torch .cuda .default_stream ()
582
+ or self ._engine_stream is None
583
+ ):
584
+ self ._engine_stream = torch .cuda .Stream ()
585
+
586
+ self ._engine_stream .wait_stream (self ._caller_stream )
587
+
588
+ with torch .cuda .stream (self ._engine_stream ):
589
+ self .context .execute_async_v3 (
590
+ self ._engine_stream .cuda_stream
591
+ ) # The OutputAllocator is called by execute_async_v3()
592
+
593
+ self ._caller_stream .wait_stream (self ._engine_stream )
594
+
595
+ with (
596
+ torch .autograd .profiler .record_function (
597
+ "PythonTorchTensorRTModule:ProcessOutputs"
598
+ )
599
+ if self .profiling_enabled
600
+ else nullcontext ()
601
+ ):
602
+ outputs = []
603
+ assert self .output_allocator is not None
604
+ for o , output_name in enumerate (self .output_names ):
605
+ shape = self .output_allocator .shapes .get (output_name , None )
606
+ dtype = self .output_dtypes [o ]
607
+ output = (
608
+ self .output_allocator .buffers .get (output_name , None )
609
+ .clone ()
610
+ .detach ()
611
+ )
612
+ prod = int (torch .prod (torch .tensor (shape )))
613
+ output = output .reshape (- 1 ).view (dtype )[:prod ].reshape (shape )
614
+ outputs .append (output )
615
+
616
+ if len (outputs ) == 1 :
617
+ return outputs [0 ]
618
+
619
+ return outputs
620
+
621
+ # Run forward function
622
+ contiguous_inputs : List [torch .Tensor ] = [
623
+ (i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
624
+ for i in inputs
625
+ ]
626
+ with (
627
+ torch .autograd .profiler .record_function ("PythonTorchTensorRTModule:Forward" )
628
+ if self .profiling_enabled
629
+ else nullcontext ()
630
+ ):
631
+ self ._check_initialized ()
632
+
633
+ # If in safe mode, check at each iteration for whether a switch is required
634
+ if (
635
+ torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
636
+ ):
637
+ curr_device_id = torch .cuda .current_device ()
638
+ curr_device_properties = torch .cuda .get_device_properties (
639
+ curr_device_id
640
+ )
641
+ logger .debug (f"Current Device: cuda:{ curr_device_id } " )
642
+
643
+ # If a switch is required, move all inputs to new device and set as active device
644
+ if _is_switch_required (
645
+ curr_device_id ,
646
+ self .target_device_id ,
647
+ curr_device_properties ,
648
+ self .target_device_properties ,
649
+ ):
650
+ device_id , _ = _select_rt_device (
651
+ curr_device_id ,
652
+ self .target_device_id ,
653
+ self .target_device_properties ,
654
+ )
655
+
656
+ # Update current device
657
+ device = torch .device (device_id )
658
+ torch .cuda .set_device (device_id )
659
+
660
+ contiguous_inputs = [
661
+ tensor .to (device ) for tensor in contiguous_inputs
662
+ ]
663
+ logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
664
+
665
+ if self .contains_dds_layer :
666
+ return run_output_allocator ()
667
+ else :
668
+ return run_cuda_graph ()
669
+
539
670
def enable_profiling (self , profiler : "trt.IProfiler" = None ) -> None :
540
671
"""
541
672
Enable TensorRT profiling. After calling this function, TensorRT will report
0 commit comments