@@ -109,7 +109,8 @@ def __init__(
         # Check if CUDA graph capture is enabled in the parent node
         self.cudagraphs_parent_module = False
         self.cudagraphs_enabled = False
-        self.persistent_output_buffer = False
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False

         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
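For orientation: the boolean persistent_output_buffer flag is replaced by two pieces of state, pre_allocated_outputs (the output tensors created at the end of the previous forward pass) and use_pre_allocated_outputs (the mode toggle). The following is an illustrative sketch of the double-buffering scheme these enable, not the module's code; the class and method names here are invented for the demo.

# Illustrative sketch (not the module's code): iteration N returns tensors
# allocated at the end of iteration N-1, and fresh buffers for N+1 are
# created before returning, moving allocation off the hot path.
from typing import List

import torch

class PreAllocatedOutputsDemo:
    def __init__(self) -> None:
        self.pre_allocated_outputs: List[torch.Tensor] = []
        self.use_pre_allocated_outputs = True

    def _allocate(self) -> List[torch.Tensor]:
        return [torch.empty(4)]  # stands in for create_output_tensors()

    def forward(self, shape_changed: bool) -> List[torch.Tensor]:
        if not self.use_pre_allocated_outputs or shape_changed or not self.pre_allocated_outputs:
            outputs = self._allocate()            # allocate on this call
        else:
            outputs = self.pre_allocated_outputs  # reuse last call's buffers
        # ... enqueue work that writes into `outputs` ...
        if self.use_pre_allocated_outputs:
            self.pre_allocated_outputs = self._allocate()  # buffers for the next call
        return outputs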
@@ -174,7 +175,7 @@ def setup_engine(self) -> None:
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_tensor_dtype(output_name))
+            dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
         self.output_shapes = [
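The one-line change above hoists the enum-to-torch conversion out of the per-call path: output_dtypes now stores ready-to-use torch.dtype values, so create_output_tensors() (added below) can pass them straight to torch.empty instead of converting on every forward pass (note the removed self.output_dtypes[o].to(torch.dtype) in the forward hunk further down). A rough sketch of the conversion, assuming torch_tensorrt's dtype enum behaves as it is used in this diff:

# Sketch only: converting a TensorRT tensor dtype to a torch.dtype once,
# at engine setup time. dtype here is torch_tensorrt's enum; _from and
# .to(torch.dtype) are the calls used in the diff above.
import torch
import tensorrt as trt
from torch_tensorrt import dtype

trt_type = trt.DataType.FLOAT  # as returned by engine.get_tensor_dtype(...)
torch_type = dtype._from(trt_type).to(torch.dtype)
assert torch_type is torch.float32  # ready for torch.empty(dtype=...)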
@@ -235,6 +236,19 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()

+    def create_output_tensors(self) -> List[torch.Tensor]:
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for o, _ in enumerate(self.output_names):
+            output = torch.empty(
+                size=self.output_shapes[o],
+                dtype=self.output_dtypes[o],
+                device=torch.cuda.current_device(),
+            )
+            outputs.append(output)
+        return outputs
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [
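The new helper centralizes output allocation so both call sites in forward (the shape-changed path and the end-of-iteration pre-allocation) share it. torch.empty is the right primitive here because it returns uninitialized device memory and the engine overwrites every element anyway; a minimal illustration of why that is preferable to a zero-fill:

# Minimal illustration (not module code): for buffers the engine fully
# overwrites, torch.empty skips the fill kernel that torch.zeros would run.
import torch

shape, dt = (1, 1000), torch.float16
buf = torch.empty(size=shape, dtype=dt, device=torch.cuda.current_device())
# buf holds garbage until the engine writes through buf.data_ptr(); that is
# fine, since every element is produced by the enqueued inference.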
@@ -350,50 +364,41 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.context.set_tensor_address(
                         input_name, contiguous_inputs[i].data_ptr()
                     )
+            if shape_changed:
+                # Check if input shapes can be inferred.
+                uninferred_input_names = self.context.infer_shapes()
+                if uninferred_input_names:
+                    logger.warning(
+                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                            This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                    )

-            # Check if input shapes can be inferred.
-            uninferred_input_names = self.context.infer_shapes()
-            if uninferred_input_names:
-                logger.warning(
-                    f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                        This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                )
-
-        with nvtx.annotate("ProcessOutputs", color="red"):
-            # create output tensors
-            outputs: List[torch.Tensor] = []
-            if not self.persistent_output_buffer or shape_changed:
-                # Create and keep persistent output buffer as long as its shape does not change
-                self._output_buffers = []
-                for o, output_name in enumerate(self.output_names):
-                    shape = tuple(self.context.get_tensor_shape(output_name))
-                    if DYNAMIC_DIM in shape:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                        )
+        with nvtx.annotate("ProcessOutputs:1", color="red"):
+            if not self.use_pre_allocated_outputs or shape_changed:
+                self.output_shapes = [
+                    tuple(self.context.get_tensor_shape(output_name))
+                    for output_name in self.output_names
+                ]
+                if DYNAMIC_DIM in self.output_shapes:
+                    raise ValueError(
+                        "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
+                    )
+                outputs = self.create_output_tensors()
+            else:
+                outputs = self.pre_allocated_outputs

-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.output_dtypes[o].to(torch.dtype),
-                        device=torch.cuda.current_device(),
+            for o, output_name in enumerate(self.output_names):
+                if need_cudagraphs_record:
+                    self._output_buffers[o] = outputs[o].clone()
+                if cudagraphs_enabled:
+                    self.context.set_tensor_address(
+                        output_name, self._output_buffers[o].data_ptr()
                     )
-                    if self.persistent_output_buffer:
-                        self.context.set_tensor_address(
-                            output_name, output.data_ptr()
-                        )
-                        self._output_buffers.append(output)
-                    else:
-                        outputs.append(output)
-                        if need_cudagraphs_record:
-                            self._output_buffers[o] = outputs[o].clone()
-                        if cudagraphs_enabled:
-                            self.context.set_tensor_address(
-                                output_name, self._output_buffers[o].data_ptr()
-                            )
-                        else:
-                            self.context.set_tensor_address(
-                                output_name, outputs[o].data_ptr()
-                            )
+                else:
+                    self.context.set_tensor_address(
+                        output_name, outputs[o].data_ptr()
+                    )
+
         with nvtx.annotate("TensorRTRuntime", color="red"):
             self._caller_stream = torch.cuda.current_stream()
             if (
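The restructured loop separates two concerns the old nested version interleaved: which tensors to return (outputs, fresh or pre-allocated) and which addresses to bind in the TensorRT context (the static _output_buffers under CUDA graphs, outputs otherwise). The fixed-address requirement comes from CUDA graphs, not TensorRT: a captured graph replays the same pointers every time. A self-contained PyTorch sketch of that pattern, independent of this module:

# Self-contained sketch of the static-buffer pattern behind the cudagraphs
# branch: I/O is staged through tensors whose addresses never change, and
# results are copied out after each replay.
import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)  # work recorded into the graph

def run(x: torch.Tensor) -> torch.Tensor:
    static_in.copy_(x)         # stage input into the captured address
    g.replay()                 # kernels write into static_out
    return static_out.clone()  # like the o.copy_(self._output_buffers[idx]) step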
@@ -439,15 +444,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

            self._caller_stream.wait_stream(self._engine_stream)

-        if self.persistent_output_buffer:
-            if len(self._output_buffers) == 1:
-                return self._output_buffers[0]
+        if self.use_pre_allocated_outputs:
+            with nvtx.annotate("ProcessOutputs:2", color="red"):
+                self.pre_allocated_outputs = self.create_output_tensors()

-            return self._output_buffers
-        else:
-            if cudagraphs_enabled:
-                for idx, o in enumerate(outputs):
-                    o.copy_(self._output_buffers[idx])
+        if cudagraphs_enabled:
+            for idx, o in enumerate(outputs):
+                o.copy_(self._output_buffers[idx])

         if len(outputs) == 1:
             return outputs[0]
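Taken together, the tail of forward now does the forward-looking allocation: with the mode enabled, the next iteration's output tensors are created while the current results are handed back, and the caller always receives tensors it owns (the old persistent_output_buffer path returned the reused buffers themselves). A hedged usage sketch follows; trt_module is a hypothetical instance of this runtime module, use_pre_allocated_outputs is the attribute added in this diff, and any wrapper API for toggling the mode is out of scope here.

# Usage sketch: toggling the new mode directly on a module instance.
# Reuse only kicks in while input shapes stay fixed; a shape change
# falls back to allocating via create_output_tensors().
import torch

trt_module.use_pre_allocated_outputs = True

inp = torch.randn(1, 3, 224, 224, device="cuda")
out = trt_module(inp)  # first call allocates and seeds pre_allocated_outputs
out = trt_module(inp)  # same shapes: reuses the buffers made last call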