 # Importing supported Backends
 import torch
+import torch_tensorrt as torchtrt
 from utils import (
     BENCHMARK_MODELS,
     parse_backends,
     precision_to_dtype,
 )

-import torch_tensorrt as torchtrt
-
 WARMUP_ITER = 10
 results = []
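The replacement build path in the next hunk calls `trt.*` and `time.time_ns()`, neither of which this import hunk introduces, so the script presumably already imports them elsewhere. A minimal sketch of the assumed imports (not part of this diff):

```python
# Assumed to exist elsewhere in the script; not shown in this diff.
import time  # time.time_ns() is used below to measure engine build time

import tensorrt as trt  # NVIDIA TensorRT Python API, aliased as `trt`
```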
@@ -294,29 +293,30 @@ def run_tensorrt(
     input_tensors,
     params,
     precision,
-    is_trt_engine=False,
     batch_size=1,
 ):
-    engine = None
-
-    # If the model file is a TensorRT engine then directly deserialize and run inference
-    # else convert the torch module to a TensorRT engine first and then run inference
-    if not is_trt_engine:
-        compile_settings = {
-            "inputs": input_tensors,
-            "enabled_precisions": {precision_to_dtype(precision)},
-            "truncate_long_and_double": params.get("truncate", False),
-        }
-
-        print("Converting method to TensorRT engine...")
-        with torch.no_grad(), torchtrt.logging.errors():
-            model = torchtrt.ts.convert_method_to_trt_engine(
-                model, "forward", **compile_settings
-            )
-
+    # Export an ONNX model and convert to TRT
+    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+    logger = trt.Logger(trt.Logger.WARNING)
+    builder = trt.Builder(logger)
+    network = builder.create_network(
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    )
+    parser = trt.OnnxParser(network, logger)
+    success = parser.parse_from_file("./tmp.onnx")
+    if not success:
+        raise ValueError("ONNX conversion failed")
+
+    config = builder.create_builder_config()
+    if precision == "fp16":
+        config.set_flag(trt.BuilderFlag.FP16)
+    start_compile = time.time_ns()
+    serialized_engine = builder.build_serialized_network(network, config)
+    end_compile = time.time_ns()
+    compile_time_s = (end_compile - start_compile) / 1e9
     # Deserialize the TensorRT engine
-    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-        engine = runtime.deserialize_cuda_engine(model)
+    with trt.Runtime(logger) as runtime:
+        engine = runtime.deserialize_cuda_engine(serialized_engine)

     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
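The timed loop that actually executes the engine sits between these hunks and is unchanged, so the diff omits it. As a rough sketch of how a deserialized TensorRT 8.x engine is typically driven with torch CUDA tensors as binding buffers (all names below are illustrative, not taken from this diff):

```python
# Illustrative only: execute the deserialized `engine` from run_tensorrt
# using torch CUDA tensors as I/O buffers (TensorRT 8.x bindings API).
import time

import tensorrt as trt
import torch

context = engine.create_execution_context()

# Allocate one output tensor per output binding, sized from engine metadata.
output_tensors = [
    torch.empty(
        tuple(engine.get_binding_shape(i)),
        dtype=torch.float16
        if engine.get_binding_dtype(i) == trt.DataType.HALF
        else torch.float32,
        device="cuda",
    )
    for i in range(engine.num_bindings)
    if not engine.binding_is_input(i)
]

# Bindings are raw device pointers: inputs first, then outputs.
bindings = [t.data_ptr() for t in input_tensors] + [
    t.data_ptr() for t in output_tensors
]

torch.cuda.synchronize()
start_time = time.perf_counter()
context.execute_v2(bindings)  # synchronous execution on the default stream
torch.cuda.synchronize()
meas_time = time.perf_counter() - start_time
```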
@@ -351,7 +351,7 @@ def run_tensorrt(
         meas_time = end_time - start_time
         timings.append(meas_time)

-    recordStats("TensorRT", timings, precision, batch_size)
+    recordStats("TensorRT", timings, precision, batch_size, compile_time_s)


 # Deploys inference run for different backend configurations
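`recordStats` is defined outside this diff; the new fifth argument implies its signature grew a compile-time field. A hypothetical sketch of a helper compatible with the new call site (field names are illustrative):

```python
# Hypothetical sketch only; the real recordStats is defined in an
# unchanged part of the script and may differ.
import numpy as np

def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None):
    times = np.array(timings)
    results.append(
        {
            "Backend": backend,
            "Precision": precision,
            "Batch size": batch_size,
            "Median(FPS)": batch_size / np.median(times),
            "Mean Latency (ms)": np.mean(times) * 1000,
            "Compile Time (s)": compile_time_s,
        }
    )
```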
@@ -427,11 +427,10 @@ def run(
             )
         elif backend == "tensorrt":
             run_tensorrt(
-                model,
+                model_torch,
                 input_tensors,
                 params,
                 precision,
-                is_trt_engine,
                 batch_size,
             )
         elif backend == "dynamo":
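Note that the call site now passes `model_torch` (the eager `torch.nn.Module`) instead of the TorchScript `model`, presumably because the new path starts from `torch.onnx.export`; `is_trt_engine` is dropped along with the deleted engine-deserialization branch.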
@@ -440,9 +439,6 @@ def run(
         elif backend == "torch_compile":
             run_torch_compile(model_torch, input_tensors, params, precision, batch_size)

-        elif backend == "torch_compile":
-            run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
-
         elif backend == "inductor":
             run_inductor(model_torch, input_tensors, params, precision, batch_size)