@@ -17,8 +17,8 @@
 # Importing supported Backends
 import torch
 import torch_tensorrt as torchtrt
-from torch_tensorrt.fx.lower import compile
-from torch_tensorrt.fx.utils import LowerPrecision
+# from torch_tensorrt.fx.lower import compile
+# from torch_tensorrt.fx.utils import LowerPrecision
 
 import tensorrt as trt
 from utils import (
@@ -134,21 +134,17 @@ def run_torch_tensorrt(
 # Runs inference using FX2TRT backend
 def run_fx2trt(model, input_tensors, params, precision, batch_size):
     print("Running FX2TRT for precision: ", precision, " batch_size: ", batch_size)
-    if precision == "fp32":
-        precision = LowerPrecision.FP32
-    elif precision == "fp16":
-        precision = LowerPrecision.FP16
+    if precision == "fp16":
         model.half()
         input_tensors = [tensor.half() for tensor in input_tensors]
+
     # Run lowering eager mode benchmark
     start_compile = time.time_ns()
-    model = compile(
+    model = torchtrt.compile(
         model,
-        input_tensors,
-        max_batch_size=batch_size,
-        lower_precision=precision,
-        verbose_log=False,
-        explicit_batch_dimension=True,
+        ir="fx",
+        inputs=input_tensors,
+        enabled_precisions={torch.float16 if precision == "fp16" else torch.float32},
     )
     end_compile = time.time_ns()
     compile_time_ms = (end_compile - start_compile) / 1e6
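
For reference, torchtrt.compile is the unified entry point and reaches the same FX lowering via ir="fx". A minimal standalone sketch, assuming a CUDA device and a toy module (neither comes from this script):

    import torch
    import torch_tensorrt as torchtrt

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3)).eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

    # ir="fx" selects the FX frontend; enabled_precisions picks the TRT kernel precision
    trt_model = torchtrt.compile(
        model,
        ir="fx",
        inputs=inputs,
        enabled_precisions={torch.float32},
    )
    out = trt_model(*inputs)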
@@ -173,6 +169,57 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
 
     recordStats("FX-TensorRT", timings, precision, batch_size, compile_time_ms)
 
+def run_dynamo(model, input_tensors, params, precision, batch_size):
+    dynamo_backend = params["dynamo_backend"]
+    print("Running Dynamo with backend: ", dynamo_backend, " for precision: ", precision, " batch_size: ", batch_size)
+
+    if precision == "fp16":
+        input_tensors = [tensor.half() for tensor in input_tensors]
+
+    fp16_mode = True if precision == "fp16" else False
+    # dynamo_backend_params = {"fp16_mode": fp16_mode}
+    # model = torch.compile(
+    #     model,
+    #     mode="default",
+    #     dynamic=False,
+    #     fullgraph=False,
+    #     backend=dynamo_backend,
+    #     # **dynamo_backend_params
+    # )
+    import torch._dynamo as dynamo
+    model = dynamo.optimize(dynamo_backend, nopython=True)(model)
+    # Compile and measure the time
+    with torch.no_grad():
+        start_compile = time.time_ns()
+        features = model(*input_tensors)
+        end_compile = time.time_ns()
+        compile_time_ms = (end_compile - start_compile) / 1e6
+        iters = params.get("iterations", 20)
+        # import pdb; pdb.set_trace()
+        print("============= DONE 0 ==================")
+
+        print("============= DONE 1 ==================")
+        # Warm up
+        model = torch._dynamo.run(model)
+        # import pdb; pdb.set_trace()
+
+        exported_model, _ = torch._dynamo.export(model, *input_tensors)
+        for i in range(WARMUP_ITER):
+            print("==== ITER: ", i)
+            features = exported_model(*input_tensors)
+
+        torch.cuda.synchronize()
+        print("============= DONE 2 ==================")
+        timings = []
+        for i in range(iters):
+            start_time = timeit.default_timer()
+            features = exported_model(*input_tensors)
+            torch.cuda.synchronize()
+            end_time = timeit.default_timer()
+            meas_time = end_time - start_time
+            timings.append(meas_time)
+
+    recordStats("Dynamo-" + dynamo_backend, timings, precision, batch_size, compile_time_ms)
 
 def torch_dtype_from_trt(dtype):
     if dtype == trt.int8:
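
The measurement pattern in run_dynamo (compile on the first call, warm up, then time steady-state runs) can be exercised in isolation. In this sketch the Linear module and the "inductor" backend string are illustrative placeholders, and a CUDA device is assumed:

    import time
    import timeit
    import torch
    import torch._dynamo as dynamo

    model = torch.nn.Linear(8, 8).eval().cuda()
    inputs = [torch.randn(4, 8, device="cuda")]

    opt_model = dynamo.optimize("inductor", nopython=True)(model)

    with torch.no_grad():
        start = time.time_ns()
        opt_model(*inputs)  # first call triggers graph capture and compilation
        compile_ms = (time.time_ns() - start) / 1e6

        for _ in range(3):  # warm-up calls reuse the compiled graph
            opt_model(*inputs)
        torch.cuda.synchronize()

        start_time = timeit.default_timer()
        opt_model(*inputs)
        torch.cuda.synchronize()
        run_ms = (timeit.default_timer() - start_time) * 1e3
        print(f"compile: {compile_ms:.1f} ms, steady-state: {run_ms:.3f} ms")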
@@ -274,7 +321,6 @@ def run(
     truncate_long_and_double=False,
     batch_size=1,
     is_trt_engine=False,
-    use_dynamo=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -307,7 +353,7 @@ def run(
             )
             continue
 
-        if backend == "all" and not use_dynamo:
+        if backend == "all":
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
                 model,
@@ -327,8 +373,9 @@ def run(
                 batch_size,
             )
             run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
+            run_dynamo(model_torch, input_tensors, params, precision, batch_size)
 
-        elif backend == "torchscript" and not use_dynamo:
+        elif backend == "torchscript":
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
                 model,
@@ -348,10 +395,10 @@ def run(
                 batch_size,
             )
 
-        elif backend == "torch" and not use_dynamo:
+        elif backend == "torch":
             run_torch(model, input_tensors, params, precision, batch_size)
 
-        elif backend == "torch_tensorrt" and not use_dynamo:
+        elif backend == "torch_tensorrt":
             run_torch_tensorrt(
                 model,
                 input_tensors,
@@ -374,6 +421,8 @@ def run(
                 is_trt_engine,
                 batch_size,
             )
+        elif backend == "dynamo":
+            run_dynamo(model_torch, input_tensors, params, precision, batch_size)
 
 
     # Generate report
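
A usage sketch for the dynamo branch: run_dynamo takes the eager nn.Module (model_torch) rather than a TorchScript module, plus a params dict carrying dynamo_backend. The values below are illustrative and assume a CUDA device:

    import torch

    params = {"dynamo_backend": "inductor", "iterations": 5}
    model_torch = torch.nn.Linear(16, 16).eval().cuda()
    input_tensors = [torch.randn(2, 16, device="cuda")]

    run_dynamo(model_torch, input_tensors, params, precision="fp32", batch_size=2)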
@@ -500,15 +549,10 @@ def load_torch_model(params):
         action="store_true",
         help="Boolean flag to determine if the user provided model is a TRT engine or not",
     )
-    arg_parser.add_argument(
-        "--dynamo",
-        action="store_true",
-        help="Boolean flag to determine if the user provided model should be compiled with torch._dynamo",
-    )
     arg_parser.add_argument(
         "--dynamo_backend",
         type=str,
-        default="inductor",
+        default="fx2trt",
         help="List of backends to use in Torchdynamo. Select options: inductor|fx2trt",
     )
     arg_parser.add_argument(
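
The --dynamo_backend default can be checked in isolation with a standalone parser mirroring the argument definition above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dynamo_backend",
        type=str,
        default="fx2trt",
        help="List of backends to use in Torchdynamo. Select options: inductor|fx2trt",
    )
    params = vars(parser.parse_args([]))
    assert params["dynamo_backend"] == "fx2trt"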
@@ -591,8 +635,6 @@ def load_torch_model(params):
 
     model_name_torch = params["model_torch"]
     model_torch = None
-    use_dynamo = params["dynamo"]
-    dynamo_backend = params["dynamo_backend"]
 
     # Load TorchScript model, if provided
     if os.path.exists(model_name):
@@ -615,21 +657,12 @@ def load_torch_model(params):
             + "or provide a torch model file"
         )
 
-    if use_dynamo and (model_torch is None):
+    backends = parse_backends(params["backends"])
+    if "dynamo" in backends and (model_torch is None):
         raise ValueError(
-            "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model"
-        )
-
-    if use_dynamo and model_torch:
-        model_torch = torch.compile(
-            model_torch,
-            "default",
-            dynamic=False,
-            fullgraph=False,
-            backend=dynamo_backend,
+            "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
         )
 
-    backends = parse_backends(params["backends"])
     truncate_long_and_double = params["truncate"]
     batch_size = params["batch_size"]
     is_trt_engine = params["is_trt_engine"]
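
In sketch form, the guard keys off the parsed backends list instead of a separate flag; the values below are hypothetical:

    backends = ["dynamo"]  # e.g. from parse_backends("dynamo")
    model_torch = None     # no eager model supplied

    if "dynamo" in backends and (model_torch is None):
        raise ValueError(
            "No Pytorch model (nn.Module) is provided for torchdynamo compilation."
            " Please provide a pytorch model using --model_torch argument"
        )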
@@ -639,9 +672,11 @@ def load_torch_model(params):
     input_tensors = parse_inputs(
         params["inputs"], precision_to_dtype(precision)
     )
+
     if not is_trt_engine and (precision == "fp16" or precision == "half"):
         # If model is TensorRT serialized engine then model.half will report failure
         model = model.half()
+
     status = run(
         model,
         backends,
@@ -651,7 +686,6 @@ def load_torch_model(params):
         truncate_long_and_double,
         batch_size,
         is_trt_engine,
-        use_dynamo,
         model_torch=model_torch,
     )