Skip to content

Commit 52aa28e

Browse files
committed
chore: Add TRT runner via onnx
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4f8eb56 commit 52aa28e

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

tools/perf/perf_run.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
# Importing supported Backends
1717
import torch
18+
import torch_tensorrt as torchtrt
1819
from utils import (
1920
BENCHMARK_MODELS,
2021
parse_backends,
@@ -23,8 +24,6 @@
2324
precision_to_dtype,
2425
)
2526

26-
import torch_tensorrt as torchtrt
27-
2827
WARMUP_ITER = 10
2928
results = []
3029

@@ -294,29 +293,27 @@ def run_tensorrt(
294293
input_tensors,
295294
params,
296295
precision,
297-
is_trt_engine=False,
298296
batch_size=1,
299297
):
300-
engine = None
301-
302-
# If the model file is a TensorRT engine then directly deserialize and run inference
303-
# else convert the torch module to a TensorRT engine first and then run inference
304-
if not is_trt_engine:
305-
compile_settings = {
306-
"inputs": input_tensors,
307-
"enabled_precisions": {precision_to_dtype(precision)},
308-
"truncate_long_and_double": params.get("truncate", False),
309-
}
310-
311-
print("Converting method to TensorRT engine...")
312-
with torch.no_grad(), torchtrt.logging.errors():
313-
model = torchtrt.ts.convert_method_to_trt_engine(
314-
model, "forward", **compile_settings
315-
)
316-
298+
# Export the PyTorch model to ONNX, parse it into a TensorRT network, and build a serialized engine
299+
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
300+
logger = trt.Logger(trt.Logger.WARNING)
301+
builder = trt.Builder(logger)
302+
network = builder.create_network(
303+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
304+
)
305+
parser = trt.OnnxParser(network, logger)
306+
success = parser.parse_from_file("./tmp.onnx")
307+
if not success:
308+
raise ValueError("ONNX conversion failed")
309+
310+
config = builder.create_builder_config()
311+
if precision == "fp16":
312+
config.set_flag(trt.BuilderFlag.FP16)
313+
serialized_engine = builder.build_serialized_network(network, config)
317314
# Deserialize the TensorRT engine
318-
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
319-
engine = runtime.deserialize_cuda_engine(model)
315+
with trt.Runtime(logger) as runtime:
316+
engine = runtime.deserialize_cuda_engine(serialized_engine)
320317

321318
print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
322319
iters = params.get("iterations", 20)
@@ -427,11 +424,10 @@ def run(
427424
)
428425
elif backend == "tensorrt":
429426
run_tensorrt(
430-
model,
427+
model_torch,
431428
input_tensors,
432429
params,
433430
precision,
434-
is_trt_engine,
435431
batch_size,
436432
)
437433
elif backend == "dynamo":
@@ -440,9 +436,6 @@ def run(
440436
elif backend == "torch_compile":
441437
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
442438

443-
elif backend == "torch_compile":
444-
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
445-
446439
elif backend == "inductor":
447440
run_inductor(model_torch, input_tensors, params, precision, batch_size)
448441

tools/perf/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
numpy
22
argparse
33
pyyaml
4+
onnx
45
transformers==4.33.2
56
diffusers==0.21.4
67
pandas==2.0.1
78
timm==0.9.8
9+

0 commit comments

Comments
 (0)