Skip to content

Commit 7779b50

Browse files
committed
feat: Add fx2trt backend and revamp current perf utility to accept CLI arguments
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a13607a commit 7779b50

File tree

1 file changed

+120
-54
lines changed

1 file changed

+120
-54
lines changed

tools/perf/perf_run.py

Lines changed: 120 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
# Importing supported Backends
1616
import torch
1717
import torch_tensorrt as torchtrt
18+
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
19+
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
20+
from torch_tensorrt.fx import TRTModule
1821
import tensorrt as trt
22+
from utils import parse_inputs, parse_backends, precision_to_dtype, BENCHMARK_MODELS
1923

2024
WARMUP_ITER = 10
2125
results = []
@@ -71,7 +75,7 @@ def run_torch(model, input_tensors, params, precision, batch_size):
7175

7276
# Runs inference using Torch-TensorRT backend
7377
def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
74-
print("Running Torch-TensorRT")
78+
print("Running Torch-TensorRT for precision: ", precision)
7579
# Compiling Torch-TensorRT model
7680
compile_settings = {
7781
"inputs": input_tensors,
@@ -82,8 +86,8 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
8286
if precision == 'int8':
8387
compile_settings.update({"calib": params.get('calibration_cache')})
8488

85-
86-
model = torchtrt.compile(model, **compile_settings)
89+
with torchtrt.logging.errors():
90+
model = torchtrt.compile(model, **compile_settings)
8791

8892
iters = params.get('iterations', 20)
8993
# Warm up
@@ -106,6 +110,55 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
106110

107111
printStats("Torch-TensorRT", timings, precision, batch_size)
108112

113+
# Runs inference using FX2TRT backend
114+
def run_fx2trt(model, input_tensors, params, precision, batch_size):
115+
print("Running FX2TRT for precision: ", precision)
116+
117+
# Trace the model with acc_tracer.
118+
acc_mod = acc_tracer.trace(model, input_tensors)
119+
# Generate input specs
120+
input_specs = InputTensorSpec.from_tensors(input_tensors)
121+
# Build a TRT interpreter. Set explicit_batch_dimension accordingly.
122+
interpreter = TRTInterpreter(
123+
acc_mod, input_specs, explicit_batch_dimension=True
124+
)
125+
trt_interpreter_result = interpreter.run(
126+
max_batch_size=batch_size,
127+
lower_precision=precision,
128+
max_workspace_size=1 << 25,
129+
sparse_weights=False,
130+
force_fp32_output=False,
131+
strict_type_constraints=False,
132+
algorithm_selector=None,
133+
timing_cache=None,
134+
profiling_verbosity=None)
135+
136+
model = TRTModule(
137+
trt_interpreter_result.engine,
138+
trt_interpreter_result.input_names,
139+
trt_interpreter_result.output_names)
140+
141+
iters = params.get('iterations', 20)
142+
# Warm up
143+
with torch.no_grad():
144+
for _ in range(WARMUP_ITER):
145+
features = model(*input_tensors)
146+
147+
torch.cuda.synchronize()
148+
149+
timings = []
150+
with torch.no_grad():
151+
for i in range(iters):
152+
start_time = timeit.default_timer()
153+
features = model(*input_tensors)
154+
torch.cuda.synchronize()
155+
end_time = timeit.default_timer()
156+
meas_time = end_time - start_time
157+
timings.append(meas_time)
158+
print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
159+
160+
printStats("FX-TensorRT", timings, precision, batch_size)
161+
109162
def torch_dtype_from_trt(dtype):
110163
if dtype == trt.int8:
111164
return torch.int8
@@ -141,19 +194,18 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
141194
}
142195

143196
print("Converting method to TensorRT engine...")
144-
with torch.no_grad():
197+
with torch.no_grad(), torchtrt.logging.errors():
145198
model = torchtrt.ts.convert_method_to_trt_engine(model, "forward", **compile_settings)
146199

147200
# Deserialize the TensorRT engine
148201
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
149202
engine = runtime.deserialize_cuda_engine(model)
150203

151-
print("Running TensorRT")
204+
print("Running TensorRT for precision: ", precision)
152205
iters = params.get('iterations', 20)
153206

154207
# Compiling the bindings
155208
bindings = engine.num_bindings * [None]
156-
# import pdb; pdb.set_trace()
157209
k = 0
158210
for idx,_ in enumerate(bindings):
159211
dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
@@ -171,12 +223,12 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
171223
timings = []
172224
with engine.create_execution_context() as context:
173225
for i in range(WARMUP_ITER):
174-
context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
226+
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
175227
torch.cuda.synchronize()
176228

177229
for i in range(iters):
178230
start_time = timeit.default_timer()
179-
context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
231+
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
180232
torch.cuda.synchronize()
181233
end_time = timeit.default_timer()
182234
meas_time = end_time - start_time
@@ -186,9 +238,8 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, b
186238
printStats("TensorRT", timings, precision, batch_size)
187239

188240
# Deploys inference run for different backend configurations
189-
def run(model, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
190-
for backend in params.get('backend'):
191-
241+
def run(model, backends, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
242+
for backend in backends:
192243
if precision == 'int8':
193244
if backend == 'all' or backend == 'torch':
194245
print("int8 precision is not supported for torch runtime in this script yet")
@@ -201,7 +252,6 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = Fals
201252
if backend == 'all':
202253
run_torch(model, input_tensors, params, precision, batch_size)
203254
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
204-
# import pdb; pdb.set_trace()
205255
run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)
206256

207257
elif backend == "torch":
@@ -210,6 +260,9 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = Fals
210260
elif backend == "torch_tensorrt":
211261
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
212262

263+
elif backend == "fx2trt":
264+
run_fx2trt(model, input_tensors, params, precision, batch_size)
265+
213266
elif backend == "tensorrt":
214267
run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)
215268

@@ -246,14 +299,6 @@ def printStats(backend, timings, precision, batch_size = 1):
246299
}
247300
results.append(meas)
248301

249-
def precision_to_dtype(pr):
250-
if pr == 'fp32':
251-
return torch.float
252-
elif pr == 'fp16' or pr == 'half':
253-
return torch.half
254-
else:
255-
return torch.int8
256-
257302
def load_model(params):
258303
model = None
259304
is_trt_engine = False
@@ -272,47 +317,68 @@ def load_model(params):
272317

273318
if __name__ == '__main__':
274319
arg_parser = argparse.ArgumentParser(description="Run inference on a model with random input values")
275-
arg_parser.add_argument("--config", help="Load YAML based configuration file to run the inference. If this is used other params will be ignored")
320+
arg_parser.add_argument("--config", type=str, help="Load YAML based configuration file to run the inference. If this is used other params will be ignored")
321+
# The following options are manual user provided settings
322+
arg_parser.add_argument("--backends", type=str, help="Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt")
323+
arg_parser.add_argument("--model", type=str, help="Name of the model file")
324+
arg_parser.add_argument("--inputs", type=str, help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT")
325+
arg_parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
326+
arg_parser.add_argument("--precision", default="fp32", type=str, help="Precision of TensorRT engine")
327+
arg_parser.add_argument("--device", type=int, help="device id")
328+
arg_parser.add_argument("--truncate", action='store_true', help="Truncate long and double weights in the network")
329+
arg_parser.add_argument("--is_trt_engine", action='store_true', help="Boolean flag to determine if the user provided model is a TRT engine or not")
276330
args = arg_parser.parse_args()
277331

278-
parser = ConfigParser(args.config)
279-
# Load YAML params
280-
params = parser.read_config()
281-
print("Loading model: ", params.get('model').get('filename'))
282-
283-
model = None
284-
285-
# Default device is set to 0. Configurable using yaml config file.
286-
torch.cuda.set_device(params.get('runtime').get('device', 0))
287-
288-
# Load the model file from disk. If the loaded file is TensorRT engine then is_trt_engine is returned as True
289-
model, is_trt_engine = load_model(params)
290332
cudnn.benchmark = True
291-
292333
# Create random input tensor of certain size
293334
torch.manual_seed(12345)
294335

295-
num_input = params.get('input').get('num_inputs')
296-
truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
297-
batch_size = params.get('input').get('batch_size', 1)
298-
for precision in params.get('runtime').get('precision', 'fp32'):
299-
input_tensors = []
300-
num_input = params.get('input').get('num_inputs', 1)
301-
for i in range(num_input):
302-
inp_tensor = params.get('input').get('input' + str(i))
303-
input_tensors.append(torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda())
304-
305-
if is_trt_engine:
306-
print("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results")
307-
308-
if not is_trt_engine and precision == "fp16" or precision == "half":
309-
# If model is TensorRT serialized engine then model.half will report failure
310-
model = model.half()
311-
336+
if args.config:
337+
parser = ConfigParser(args.config)
338+
# Load YAML params
339+
params = parser.read_config()
340+
print("Loading model: ", params.get('model').get('filename'))
341+
model_file = params.get('model').get('filename')
342+
# Default device is set to 0. Configurable using yaml config file.
343+
torch.cuda.set_device(params.get('runtime').get('device', 0))
344+
345+
num_input = params.get('input').get('num_inputs')
346+
truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
347+
batch_size = params.get('input').get('batch_size', 1)
348+
for precision in params.get('runtime').get('precision', 'fp32'):
349+
input_tensors = []
350+
num_input = params.get('input').get('num_inputs', 1)
351+
for i in range(num_input):
352+
inp_tensor = params.get('input').get('input' + str(i))
353+
input_tensors.append(torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda())
354+
355+
if is_trt_engine:
356+
print("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results")
357+
358+
if not is_trt_engine and precision == "fp16" or precision == "half":
359+
# If model is TensorRT serialized engine then model.half will report failure
360+
model = model.half()
361+
backends = params.get('backend')
362+
# Run inference
363+
status = run(model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
364+
else:
365+
params = vars(args)
366+
model_name = params['model']
367+
if os.path.exists(model_name):
368+
print("Loading user provided model: ", model_name)
369+
model = torch.jit.load(model_name).cuda().eval()
370+
elif model_name in BENCHMARK_MODELS:
371+
model = BENCHMARK_MODELS[model_name]['model'].eval().cuda()
372+
else:
373+
raise ValueError("Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)")
374+
precision = params['precision']
375+
input_tensors = parse_inputs(params['inputs'])
376+
backends = parse_backends(params['backends'])
377+
truncate_long_and_double = params.get('truncate', False)
378+
batch_size = params['batch_size']
379+
is_trt_engine = params['is_trt_engine']
312380
# Run inference
313-
status = run(model, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
314-
if status == False:
315-
continue
381+
status = run(model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
316382

317383
# Generate report
318384
print('Model Summary:')

0 commit comments

Comments
 (0)