# Importing supported Backends
import torch
import torch_tensorrt as torchtrt
+ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+ from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
+ from torch_tensorrt.fx import TRTModule
import tensorrt as trt
+ from utils import parse_inputs, parse_backends, precision_to_dtype, BENCHMARK_MODELS

WARMUP_ITER = 10

results = []
@@ -71,7 +75,7 @@ def run_torch(model, input_tensors, params, precision, batch_size):

# Runs inference using Torch-TensorRT backend
def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
-     print("Running Torch-TensorRT")
+     print("Running Torch-TensorRT for precision: ", precision)
    # Compiling Torch-TensorRT model
    compile_settings = {
        "inputs": input_tensors,
@@ -82,8 +86,8 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
    if precision == 'int8':
        compile_settings.update({"calib": params.get('calibration_cache')})

- 
-     model = torchtrt.compile(model, **compile_settings)
+     with torchtrt.logging.errors():
+         model = torchtrt.compile(model, **compile_settings)

    iters = params.get('iterations', 20)
    # Warm up
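(Note on the change above: `torchtrt.logging.errors()` is a Torch-TensorRT context manager that raises the logging threshold for the duration of the block so that only errors are reported; here it keeps compilation noise out of the benchmark output.)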
@@ -106,6 +110,55 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):

    printStats("Torch-TensorRT", timings, precision, batch_size)

+ # Runs inference using FX2TRT backend
+ def run_fx2trt(model, input_tensors, params, precision, batch_size):
+     print("Running FX2TRT for precision: ", precision)
+ 
+     # Trace the model with acc_tracer.
+     acc_mod = acc_tracer.trace(model, input_tensors)
+     # Generate input specs
+     input_specs = InputTensorSpec.from_tensors(input_tensors)
+     # Build a TRT interpreter. Set explicit_batch_dimension accordingly.
+     interpreter = TRTInterpreter(
+         acc_mod, input_specs, explicit_batch_dimension=True
+     )
+     trt_interpreter_result = interpreter.run(
+         max_batch_size=batch_size,
+         lower_precision=precision,
+         max_workspace_size=1 << 25,
+         sparse_weights=False,
+         force_fp32_output=False,
+         strict_type_constraints=False,
+         algorithm_selector=None,
+         timing_cache=None,
+         profiling_verbosity=None)
+ 
+     model = TRTModule(
+         trt_interpreter_result.engine,
+         trt_interpreter_result.input_names,
+         trt_interpreter_result.output_names)
+ 
+     iters = params.get('iterations', 20)
+     # Warm up
+     with torch.no_grad():
+         for _ in range(WARMUP_ITER):
+             features = model(*input_tensors)
+ 
+     torch.cuda.synchronize()
+ 
+     timings = []
+     with torch.no_grad():
+         for i in range(iters):
+             start_time = timeit.default_timer()
+             features = model(*input_tensors)
+             torch.cuda.synchronize()
+             end_time = timeit.default_timer()
+             meas_time = end_time - start_time
+             timings.append(meas_time)
+             print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
+ 
+     printStats("FX-TensorRT", timings, precision, batch_size)
+ 
def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
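For readers new to the FX path, the `run_fx2trt` added above boils down to four steps: trace the module to acc ops, derive input specs from sample tensors, build an engine with the interpreter, and wrap the result as a module. A minimal standalone sketch of that flow, restricted to calls that appear in this diff (the toy model and the reliance on `interpreter.run()` defaults are assumptions):

    # Hypothetical standalone sketch of the run_fx2trt flow
    import torch
    import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
    from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule

    mod = torch.nn.Linear(8, 4).cuda().eval()      # assumed toy model
    inputs = [torch.randn(1, 8).cuda()]
    acc_mod = acc_tracer.trace(mod, inputs)        # 1. rewrite the graph into acc ops
    specs = InputTensorSpec.from_tensors(inputs)   # 2. record input shapes/dtypes
    interp = TRTInterpreter(acc_mod, specs, explicit_batch_dimension=True)
    res = interp.run(max_workspace_size=1 << 25)   # 3. build the TensorRT engine
    trt_mod = TRTModule(res.engine, res.input_names, res.output_names)  # 4. wrap as a module
    print(trt_mod(*inputs).shape)                  # drop-in replacement for mod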
@@ -141,19 +194,18 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, batch_size=1):
    }

    print("Converting method to TensorRT engine...")
-     with torch.no_grad():
+     with torch.no_grad(), torchtrt.logging.errors():
        model = torchtrt.ts.convert_method_to_trt_engine(model, "forward", **compile_settings)

    # Deserialize the TensorRT engine
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(model)

-     print("Running TensorRT")
+     print("Running TensorRT for precision: ", precision)
    iters = params.get('iterations', 20)

    # Compiling the bindings
    bindings = engine.num_bindings * [None]
-     # import pdb; pdb.set_trace()
    k = 0
    for idx, _ in enumerate(bindings):
        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
@@ -171,12 +223,12 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, batch_size=1):
    timings = []
    with engine.create_execution_context() as context:
        for i in range(WARMUP_ITER):
-             context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
+             context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
        torch.cuda.synchronize()

        for i in range(iters):
            start_time = timeit.default_timer()
-             context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
+             context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            meas_time = end_time - start_time
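(The switch from `execute_async` to `execute_async_v2` is the API-correct change here: `execute_async(batch_size, bindings, stream_handle)` belongs to TensorRT's implicit-batch interface, while `execute_async_v2(bindings, stream_handle)` targets explicit-batch engines and takes no batch-size argument, the batch dimension being part of the engine's binding shapes.)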
@@ -186,9 +238,8 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, batch_size=1):
    printStats("TensorRT", timings, precision, batch_size)

# Deploys inference run for different backend configurations
- def run(model, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
-     for backend in params.get('backend'):
- 
+ def run(model, backends, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
+     for backend in backends:
        if precision == 'int8':
            if backend == 'all' or backend == 'torch':
                print("int8 precision is not supported for torch runtime in this script yet")
@@ -201,7 +252,6 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
        if backend == 'all':
            run_torch(model, input_tensors, params, precision, batch_size)
            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
-             # import pdb; pdb.set_trace()
            run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)

        elif backend == "torch":
@@ -210,6 +260,9 @@ def run(model, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
        elif backend == "torch_tensorrt":
            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)

+         elif backend == "fx2trt":
+             run_fx2trt(model, input_tensors, params, precision, batch_size)
+ 
        elif backend == "tensorrt":
            run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)

@@ -246,14 +299,6 @@ def printStats(backend, timings, precision, batch_size = 1):
    }
    results.append(meas)

- def precision_to_dtype(pr):
-     if pr == 'fp32':
-         return torch.float
-     elif pr == 'fp16' or pr == 'half':
-         return torch.half
-     else:
-         return torch.int8
- 
def load_model(params):
    model = None
    is_trt_engine = False
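The removed `precision_to_dtype` now lives in `utils`, which is imported at the top of the file alongside `parse_inputs`, `parse_backends`, and `BENCHMARK_MODELS`. Those helpers are not shown in this diff, so the following is only a plausible sketch of their shape, inferred from the removed body above and from how they are called below (`parse_backends` is purely hypothetical):

    # Plausible shape of the helpers imported from utils (hypothetical sketch;
    # only precision_to_dtype is confirmed, by the removed body above)
    import torch

    def precision_to_dtype(pr):
        if pr == 'fp32':
            return torch.float
        elif pr == 'fp16' or pr == 'half':
            return torch.half
        else:
            return torch.int8

    def parse_backends(backends):
        # "torch,torch_tensorrt,tensorrt" -> ["torch", "torch_tensorrt", "tensorrt"]
        return backends.split(',')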
@@ -272,47 +317,68 @@ def load_model(params):

if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description="Run inference on a model with random input values")
-     arg_parser.add_argument("--config", help="Load YAML based configuration file to run the inference. If this is used, other params will be ignored")
+     arg_parser.add_argument("--config", type=str, help="Load YAML based configuration file to run the inference. If this is used, other params will be ignored")
+     # The following options are manual user provided settings
+     arg_parser.add_argument("--backends", type=str, help="Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt")
+     arg_parser.add_argument("--model", type=str, help="Name of the model file")
+     arg_parser.add_argument("--inputs", type=str, help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT")
+     arg_parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+     arg_parser.add_argument("--precision", default="fp32", type=str, help="Precision of TensorRT engine")
+     arg_parser.add_argument("--device", type=int, help="device id")
+     arg_parser.add_argument("--truncate", action='store_true', help="Truncate long and double weights in the network")
+     arg_parser.add_argument("--is_trt_engine", action='store_true', help="Boolean flag to determine if the user provided model is a TRT engine or not")
    args = arg_parser.parse_args()

-     parser = ConfigParser(args.config)
-     # Load YAML params
-     params = parser.read_config()
-     print("Loading model: ", params.get('model').get('filename'))
- 
-     model = None
- 
-     # Default device is set to 0. Configurable using yaml config file.
-     torch.cuda.set_device(params.get('runtime').get('device', 0))
- 
-     # Load the model file from disk. If the loaded file is TensorRT engine then is_trt_engine is returned as True
-     model, is_trt_engine = load_model(params)
    cudnn.benchmark = True
- 
    # Create random input tensor of certain size
    torch.manual_seed(12345)

-     num_input = params.get('input').get('num_inputs')
-     truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
-     batch_size = params.get('input').get('batch_size', 1)
-     for precision in params.get('runtime').get('precision', 'fp32'):
-         input_tensors = []
-         num_input = params.get('input').get('num_inputs', 1)
-         for i in range(num_input):
-             inp_tensor = params.get('input').get('input' + str(i))
-             input_tensors.append(torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda())
- 
-         if is_trt_engine:
-             print("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results")
- 
-         if not is_trt_engine and precision == "fp16" or precision == "half":
-             # If model is TensorRT serialized engine then model.half will report failure
-             model = model.half()
- 
+     if args.config:
+         parser = ConfigParser(args.config)
+         # Load YAML params
+         params = parser.read_config()
+         print("Loading model: ", params.get('model').get('filename'))
+         model_file = params.get('model').get('filename')
+         # Default device is set to 0. Configurable using yaml config file.
+         torch.cuda.set_device(params.get('runtime').get('device', 0))
+ 
+         num_input = params.get('input').get('num_inputs')
+         truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
+         batch_size = params.get('input').get('batch_size', 1)
+         for precision in params.get('runtime').get('precision', 'fp32'):
+             input_tensors = []
+             num_input = params.get('input').get('num_inputs', 1)
+             for i in range(num_input):
+                 inp_tensor = params.get('input').get('input' + str(i))
+                 input_tensors.append(torch.randint(0, 2, tuple(d for d in inp_tensor), dtype=precision_to_dtype(precision)).cuda())
+ 
+             if is_trt_engine:
+                 print("Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results")
+ 
+             if not is_trt_engine and (precision == "fp16" or precision == "half"):
+                 # If model is TensorRT serialized engine then model.half will report failure
+                 model = model.half()
+             backends = params.get('backend')
+             # Run inference
+             status = run(model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
+     else:
+         params = vars(args)
+         model_name = params['model']
+         if os.path.exists(model_name):
+             print("Loading user provided model: ", model_name)
+             model = torch.jit.load(model_name).cuda().eval()
+         elif model_name in BENCHMARK_MODELS:
+             model = BENCHMARK_MODELS[model_name]['model'].eval().cuda()
+         else:
+             raise ValueError("Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)")
+         precision = params['precision']
+         input_tensors = parse_inputs(params['inputs'])
+         backends = parse_backends(params['backends'])
+         truncate_long_and_double = params.get('truncate', False)
+         batch_size = params['batch_size']
+         is_trt_engine = params['is_trt_engine']
    # Run inference
-         status = run(model, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
-         if status == False:
-             continue
+     status = run(model, backends, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)

    # Generate report
    print('Model Summary:')
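With the new flags, the script can now be driven entirely from the command line, bypassing the YAML config. A hypothetical invocation (the script filename is not shown in this diff; perf_run.py is assumed), using only flags defined above:

    python perf_run.py --backends torch,torch_tensorrt,tensorrt \
                       --model resnet50 --inputs "(1, 3, 224, 224)@fp32" \
                       --batch_size 1 --precision fp16

Here --model accepts either a path to a TorchScript file or one of the built-in names (vgg16|resnet50|efficientnet_b0|vit), and --backends is the comma-separated list consumed by parse_backends.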