@@ -108,6 +108,9 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
+        self.cudagraphs_enabled = False
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -172,7 +175,7 @@ def setup_engine(self) -> None:
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_tensor_dtype(output_name))
+            dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
         self.output_shapes = [
@@ -233,6 +236,19 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
+    def create_output_tensors(self) -> List[torch.Tensor]:
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for o, _ in enumerate(self.output_names):
+            output = torch.empty(
+                size=self.output_shapes[o],
+                dtype=self.output_dtypes[o],
+                device=torch.cuda.current_device(),
+            )
+            outputs.append(output)
+        return outputs
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [
@@ -248,19 +264,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         self._check_initialized()
 
         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-        need_cudagraphs_record = (
-            cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
-        )
+        shape_changed = self.validate_input_shapes(inputs)
+        # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+        if not self.cudagraphs_enabled and cudagraphs_enabled:
+            need_cudagraphs_record = True
+        else:
+            need_cudagraphs_record = cudagraphs_enabled and shape_changed
+        self.cudagraphs_enabled = cudagraphs_enabled
 
         if need_cudagraphs_record:
+            if self.cudagraph:
+                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)
 
         if not cudagraphs_enabled and self.cudagraph:
             self.cudagraph.reset()
             self.cudagraph = None
 
-        # If in safe mode, check at each iteration for for whether a switch is required
+        # If in safe mode, check at each iteration for whether a switch is required
         if (
             torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
         ):
@@ -351,14 +373,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                         self.context.set_tensor_address(
                             input_name, contiguous_inputs[i].data_ptr()
                         )
-
-            # Check if input shapes can be inferred.
-            uninferred_input_names = self.context.infer_shapes()
-            if uninferred_input_names:
-                logger.warning(
-                    f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                    This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                )
+            if shape_changed:
+                # Check if input shapes can be inferred.
+                uninferred_input_names = self.context.infer_shapes()
+                if uninferred_input_names:
+                    logger.warning(
+                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                        This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                    )
 
         with (
             torch.autograd.profiler.record_function(
@@ -367,24 +389,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.profiling_enabled
             else nullcontext()
         ):
-            # create output tensors
-            outputs: List[torch.Tensor] = []
-
-            for o, output_name in enumerate(self.output_names):
-                shape = tuple(self.context.get_tensor_shape(output_name))
-
-                if DYNAMIC_DIM in shape:
+            if not self.use_pre_allocated_outputs or shape_changed:
+                self.output_shapes = [
+                    tuple(self.context.get_tensor_shape(output_name))
+                    for output_name in self.output_names
+                ]
+                if DYNAMIC_DIM in self.output_shapes:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                     )
+                outputs = self.create_output_tensors()
+            else:
+                outputs = self.pre_allocated_outputs
 
-                output = torch.empty(
-                    size=shape,
-                    dtype=self.output_dtypes[o].to(torch.dtype),
-                    device=torch.cuda.current_device(),
-                )
-
-                outputs.append(output)
+            for o, output_name in enumerate(self.output_names):
 
                 if need_cudagraphs_record:
                     self._output_buffers[o] = outputs[o].clone()
@@ -445,6 +463,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             self._caller_stream.wait_stream(self._engine_stream)
 
+        if self.use_pre_allocated_outputs:
+            self.pre_allocated_outputs = self.create_output_tensors()
+
         if cudagraphs_enabled:
             for idx, o in enumerate(outputs):
                 o.copy_(self._output_buffers[idx])
@@ -486,10 +507,9 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function
-        versus the version currently active for the
+        Validates the input shapes of the forward function has changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
499
519
# If the new shape key differs from the existing one,
500
520
# invalidate the old shape key and remove the CUDAGraph
501
521
if new_shape_key != self .shape_key :
502
- logger .debug (f"Resetting Cudagraph on new shape key { new_shape_key } " )
522
+ logger .debug (f"Input shape changed { self . shape_key } -> { new_shape_key } " )
503
523
self .shape_key = new_shape_key
504
- if self .cudagraph :
505
- self .cudagraph .reset ()
506
- return False
524
+ return True
507
525
508
- return True
526
+ return False
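
Taken together, the hunks above add an opt-in pre-allocated output buffer path and tie CUDA graph re-recording to the new shape-change check. A minimal usage sketch follows, assuming `trt_module` is an already-built instance of the runtime module patched here; the module construction and the example input shapes are placeholders, not part of the diff:

import torch

# Opt in to reusing output buffers across calls (attribute added in this diff).
trt_module.use_pre_allocated_outputs = True

x = torch.randn(1, 3, 224, 224, device="cuda")
out = trt_module(x)  # first call: output shapes recorded, output tensors allocated
out = trt_module(x)  # same shape key: pre-allocated outputs are reused
y = torch.randn(8, 3, 224, 224, device="cuda")
out = trt_module(y)  # validate_input_shapes() reports a change, buffers are re-created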