@@ -107,6 +107,9 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
+        self.cudagraphs_enabled = False
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False

         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -171,7 +174,7 @@ def setup_engine(self) -> None:
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_tensor_dtype(output_name))
+            dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
         self.output_shapes = [
@@ -232,6 +235,19 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()

+    def create_output_tensors(self) -> List[torch.Tensor]:
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for o, _ in enumerate(self.output_names):
+            output = torch.empty(
+                size=self.output_shapes[o],
+                dtype=self.output_dtypes[o],
+                device=torch.cuda.current_device(),
+            )
+            outputs.append(output)
+        return outputs
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [
@@ -247,19 +263,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         self._check_initialized()

         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-        need_cudagraphs_record = (
-            cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
-        )
+        shape_changed = self.validate_input_shapes(inputs)
+        # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+        if not self.cudagraphs_enabled and cudagraphs_enabled:
+            need_cudagraphs_record = True
+        else:
+            need_cudagraphs_record = cudagraphs_enabled and shape_changed
+        self.cudagraphs_enabled = cudagraphs_enabled

         if need_cudagraphs_record:
+            if self.cudagraph:
+                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)

         if not cudagraphs_enabled and self.cudagraph:
             self.cudagraph.reset()
             self.cudagraph = None

-        # If in safe mode, check at each iteration for for whether a switch is required
+        # If in safe mode, check at each iteration for whether a switch is required
         if (
             torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
         ):
@@ -350,14 +372,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.context.set_tensor_address(
                     input_name, contiguous_inputs[i].data_ptr()
                 )
-
-            # Check if input shapes can be inferred.
-            uninferred_input_names = self.context.infer_shapes()
-            if uninferred_input_names:
-                logger.warning(
-                    f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                    This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                )
+            if shape_changed:
+                # Check if input shapes can be inferred.
+                uninferred_input_names = self.context.infer_shapes()
+                if uninferred_input_names:
+                    logger.warning(
+                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                        This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                    )

         with (
             torch.autograd.profiler.record_function(
@@ -366,24 +388,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.profiling_enabled
             else nullcontext()
         ):
-            # create output tensors
-            outputs: List[torch.Tensor] = []
-
-            for o, output_name in enumerate(self.output_names):
-                shape = tuple(self.context.get_tensor_shape(output_name))
-
-                if DYNAMIC_DIM in shape:
+            if not self.use_pre_allocated_outputs or shape_changed:
+                self.output_shapes = [
+                    tuple(self.context.get_tensor_shape(output_name))
+                    for output_name in self.output_names
+                ]
+                if DYNAMIC_DIM in self.output_shapes:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                     )
+                outputs = self.create_output_tensors()
+            else:
+                outputs = self.pre_allocated_outputs

-                output = torch.empty(
-                    size=shape,
-                    dtype=self.output_dtypes[o].to(torch.dtype),
-                    device=torch.cuda.current_device(),
-                )
-
-                outputs.append(output)
+            for o, output_name in enumerate(self.output_names):

             if need_cudagraphs_record:
                 self._output_buffers[o] = outputs[o].clone()
@@ -444,6 +462,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

             self._caller_stream.wait_stream(self._engine_stream)

+        if self.use_pre_allocated_outputs:
+            self.pre_allocated_outputs = self.create_output_tensors()
+
         if cudagraphs_enabled:
             for idx, o in enumerate(outputs):
                 o.copy_(self._output_buffers[idx])
@@ -485,10 +506,9 @@ def get_layer_info(self) -> str:
         )
         return engine_json

-    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function
-        versus the version currently active for the
+        Validates the input shapes of the forward function has changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
@@ -498,10 +518,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
+            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-            if self.cudagraph:
-                self.cudagraph.reset()
-            return False
+            return True

-        return True
+        return False
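Taken together, the hunks above do two things: they force a CUDA Graphs re-record whenever the cudagraphs mode is toggled on (independently of input shapes), and they reuse pre-allocated output tensors across calls whenever use_pre_allocated_outputs is set and the input shape key has not changed. The snippet below is a minimal, self-contained sketch of just that decision logic; _BufferPolicyStub and its fields are hypothetical stand-ins for illustration and are not part of this change, only the branching mirrors the diff.

# Hypothetical stand-in (not the real module): isolates the record/reuse branching from the diff above.
from typing import Optional, Tuple


class _BufferPolicyStub:
    def __init__(self) -> None:
        self.cudagraphs_enabled = False          # last observed CUDA Graphs mode
        self.use_pre_allocated_outputs = False   # opt-in flag, as in the module above
        self.shape_key: Optional[str] = None     # stands in for validate_input_shapes() state

    def decide(self, cudagraphs_enabled: bool, new_shape_key: str) -> Tuple[bool, bool]:
        """Return (need_cudagraphs_record, reuse_pre_allocated_outputs) for one forward call."""
        shape_changed = new_shape_key != self.shape_key
        if shape_changed:
            self.shape_key = new_shape_key

        # Re-record whenever CUDA Graphs is toggled from off to on, regardless of shape change;
        # otherwise only a shape change while cudagraphs is enabled triggers a re-record.
        if not self.cudagraphs_enabled and cudagraphs_enabled:
            need_cudagraphs_record = True
        else:
            need_cudagraphs_record = cudagraphs_enabled and shape_changed
        self.cudagraphs_enabled = cudagraphs_enabled

        # Pre-allocated outputs are reused only when the flag is set and shapes are unchanged.
        reuse_outputs = self.use_pre_allocated_outputs and not shape_changed
        return need_cudagraphs_record, reuse_outputs


policy = _BufferPolicyStub()
policy.use_pre_allocated_outputs = True
print(policy.decide(True, "(1,3,224,224)"))  # (True, False): first call records and allocates outputs
print(policy.decide(True, "(1,3,224,224)"))  # (False, True): same shape, reuse pre-allocated outputs
print(policy.decide(True, "(8,3,224,224)"))  # (True, False): shape change forces re-record and re-allocation

In the real module the same decision feeds create_output_tensors() and the CUDA Graph capture path; it is isolated here purely to make the control flow easy to follow.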