@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,6 +194,9 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
+        self.total_dynamic_shape: Optional[dict[Any, Any]] = None

         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
@@ -203,6 +211,31 @@ def __init__(
         )
         self.init_finished = True

+    def set_dynamic_shape_hint(
+        self,
+        args_dynamic_shape: tuple[dict[Any, Any]],
+        kwargs_dynamic_shape: dict[str, Any],
+    ) -> None:
+        assert isinstance(
+            args_dynamic_shape, tuple
+        ), "args dynamic shape has to be a tuple"
+        assert isinstance(
+            kwargs_dynamic_shape, dict
+        ), "kwargs dynamic shape has to be a dictionary"
+        self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+        self.arg_dynamic_shapes = args_dynamic_shape
+        self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
+        signature = list(
+            inspect.signature(self.original_model.forward).parameters.keys()
+        )
+        for i, arg in enumerate(self.arg_dynamic_shapes):
+            self.total_dynamic_shape[signature[i]] = arg
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+        # Clear cached inputs
+        self.arg_inputs = tuple()
+        self.kwarg_inputs = {}
+
     def store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape
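A minimal usage sketch of the hint API added above (illustrative, not part of the diff; MyModel is a hypothetical module whose forward takes one positional tensor). The per-dimension range objects come from torch.export.Dim, whose min/max attributes are what the range check added further below reads:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # hypothetical: forward(self, x)
    trt_module = torch_tensorrt.MutableTorchTensorRTModule(model)

    # Let dimension 0 of the first positional input vary in [1, 8];
    # this marks the module for recompilation with dynamic shapes.
    batch = torch.export.Dim("batch", min=1, max=8)
    trt_module.set_dynamic_shape_hint(({0: batch},), {})

    # Any batch size within the hinted range reuses the same engine.
    trt_module(torch.randn(2, 3, 224, 224, device="cuda"))
    trt_module(torch.randn(8, 3, 224, 224, device="cuda"))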
@@ -295,6 +328,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self.total_dynamic_shape,
         )
         self.gm = dynamo_compile(
             self.exp_program,
@@ -306,14 +340,26 @@ def compile(self) -> None:
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+        try:
+            if (
+                not self.arg_inputs
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+                )
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+                )
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self.store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning("Recompiling the engine with static shape")
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
+            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
             self.store_inputs(args, kwargs)

@@ -436,35 +482,68 @@ def __setattr__(self, name: str, value: Any) -> None:
     def check_inputs_equal(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
-                    return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule.check_inputs_equal(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
                     return False
         return True

+    @staticmethod
+    def check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
+                    return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+
+        return True
+
     @staticmethod
     def save(module: Any, path: str) -> None:
         # Cast the object back to MutableTorchTensorRTModule to save
# Cast the object back to MutableTorchTensorRTModule to save
0 commit comments