@@ -274,6 +274,7 @@ def run(
     truncate_long_and_double=False,
     batch_size=1,
     is_trt_engine=False,
+    use_dynamo=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -306,7 +307,7 @@ def run(
                 )
                 continue

-        if backend == "all":
+        if backend == "all" and not use_dynamo:
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
                 model,
@@ -327,7 +328,7 @@ def run(
             )
             run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

-        elif backend == "torchscript":
+        elif backend == "torchscript" and not use_dynamo:
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
                 model,
@@ -347,10 +348,10 @@ def run(
                 batch_size,
             )

-        elif backend == "torch":
+        elif backend == "torch" and not use_dynamo:
             run_torch(model, input_tensors, params, precision, batch_size)

-        elif backend == "torch_tensorrt":
+        elif backend == "torch_tensorrt" and not use_dynamo:
             run_torch_tensorrt(
                 model,
                 input_tensors,
@@ -499,6 +500,17 @@ def load_torch_model(params):
         action="store_true",
         help="Boolean flag to determine if the user provided model is a TRT engine or not",
     )
+    arg_parser.add_argument(
+        "--dynamo",
+        action="store_true",
+        help="Boolean flag to determine if the user provided model should be compiled with torch._dynamo",
+    )
+    arg_parser.add_argument(
+        "--dynamo_backend",
+        type=str,
+        default="inductor",
+        help="Backend to use in Torchdynamo. Select options: inductor|fx2trt",
+    )
     arg_parser.add_argument(
         "--report",
         type=str,
@@ -579,6 +591,8 @@ def load_torch_model(params):

     model_name_torch = params["model_torch"]
     model_torch = None
+    use_dynamo = params["dynamo"]
+    dynamo_backend = params["dynamo_backend"]

     # Load TorchScript model, if provided
     if os.path.exists(model_name):
@@ -601,6 +615,12 @@ def load_torch_model(params):
             + "or provide a torch model file"
         )

+    if use_dynamo and (model_torch is None):
+        raise ValueError("No PyTorch model (nn.Module) is provided for torchdynamo compilation. Please provide a PyTorch model")
+
+    if use_dynamo and model_torch:
+        model_torch = torch.compile(model_torch, mode="default", dynamic=False, fullgraph=False, backend=dynamo_backend)
+
     backends = parse_backends(params["backends"])
     truncate_long_and_double = params["truncate"]
     batch_size = params["batch_size"]
@@ -623,6 +643,7 @@ def load_torch_model(params):
         truncate_long_and_double,
         batch_size,
         is_trt_engine,
+        use_dynamo,
         model_torch=model_torch,
     )

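The heart of this change is the `torch.compile` wrapping added above: when `--dynamo` is set, the user-provided `nn.Module` is compiled through the selected TorchDynamo backend before it is passed to the benchmark runners. A minimal, self-contained sketch of that same path is shown below; the toy model, input shape, and backend choice are illustrative placeholders, not part of this diff.

```python
# Sketch of the --dynamo / --dynamo_backend path (placeholder model and backend).
import torch
import torch.nn as nn

# Stand-in for the user-provided PyTorch model (params["model_torch"] in the script).
model_torch = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).eval()

# Equivalent to running the benchmark with --dynamo --dynamo_backend inductor.
compiled = torch.compile(
    model_torch, mode="default", dynamic=False, fullgraph=False, backend="inductor"
)

with torch.no_grad():
    out = compiled(torch.randn(1, 16))  # first call triggers compilation
print(out.shape)
```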