@@ -26,7 +26,7 @@ def __init__(self, config_file):
         self.parser = None
         self.config = config_file
         self.params = None
-
+
     # Reads and loads the yaml file
     def read_config(self):
         with open(self.config, "r") as stream:
@@ -45,7 +45,7 @@ def get(self, key, default_value=None):
             return self.params[key]

 # Runs inference using Torch backend
-def run_torch(model, input_tensors, params, precision):
+def run_torch(model, input_tensors, params, precision, batch_size):
     print("Running Torch for precision: ", precision)
     iters = params.get('iterations', 20)

@@ -66,25 +66,25 @@ def run_torch(model, input_tensors, params, precision):
             meas_time = end_time - start_time
             timings.append(meas_time)
             print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
-
-    printStats("Torch", timings, precision)
+
+    printStats("Torch", timings, precision, batch_size)

 # Runs inference using Torch-TensorRT backend
-def run_torch_tensorrt(model, input_tensors, params, precision):
+def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
     print("Running Torch-TensorRT")
-
     # Compiling Torch-TensorRT model
     compile_settings = {
         "inputs": input_tensors,
-        "enabled_precisions": {precision_to_dtype(precision)}
+        "enabled_precisions": {precision_to_dtype(precision)},
+        "truncate_long_and_double": truncate_long_and_double,
     }

     if precision == 'int8':
         compile_settings.update({"calib": params.get('calibration_cache')})

-
+
     model = torchtrt.compile(model, **compile_settings)
-
+
     iters = params.get('iterations', 20)
     # Warm up
     with torch.no_grad():
@@ -103,8 +103,8 @@ def run_torch_tensorrt(model, input_tensors, params, precision):
             meas_time = end_time - start_time
             timings.append(meas_time)
             print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
-
-    printStats("Torch-TensorRT", timings, precision)
+
+    printStats("Torch-TensorRT", timings, precision, batch_size)

 def torch_dtype_from_trt(dtype):
     if dtype == trt.int8:
@@ -129,7 +129,7 @@ def torch_device_from_trt(device):
         return TypeError("%s is not supported by torch" % device)


-def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
+def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, batch_size=1):
     engine = None

     # If the model file is a TensorRT engine then directly deserialize and run inference
@@ -143,22 +143,21 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
         print("Converting method to TensorRT engine...")
         with torch.no_grad():
             model = torchtrt.ts.convert_method_to_trt_engine(model, "forward", **compile_settings)
-
+
     # Deserialize the TensorRT engine
     with trt.Logger() as logger, trt.Runtime(logger) as runtime:
         engine = runtime.deserialize_cuda_engine(model)
-
+
     print("Running TensorRT")
     iters = params.get('iterations', 20)
-    batch_size = params.get('batch', 1)

     # Compiling the bindings
     bindings = engine.num_bindings * [None]
-
+    # import pdb; pdb.set_trace()
     k = 0
     for idx, _ in enumerate(bindings):
         dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
-        shape = (batch_size,) + tuple(engine.get_binding_shape(idx))
+        shape = tuple(engine.get_binding_shape(idx))
         device = torch_device_from_trt(engine.get_location(idx))
         if not engine.binding_is_input(idx):
             # Output bindings
@@ -168,26 +167,26 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
             # Input bindings
             bindings[idx] = input_tensors[k].data_ptr()
             k += 1
-
+
     timings = []
     with engine.create_execution_context() as context:
         for i in range(WARMUP_ITER):
-            context.execute_async(batch_size, bindings, torch.cuda.current_stream().cuda_stream)
+            context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
         torch.cuda.synchronize()

         for i in range(iters):
             start_time = timeit.default_timer()
-            context.execute_async(batch_size, bindings, torch.cuda.current_stream().cuda_stream)
+            context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
             torch.cuda.synchronize()
             end_time = timeit.default_timer()
             meas_time = end_time - start_time
             timings.append(meas_time)
             print("Iterations {}: {:.6f} s".format(i, end_time - start_time))
-
-    printStats("TensorRT", timings, precision)
+
+    printStats("TensorRT", timings, precision, batch_size)

 # Deploys inference run for different backend configurations
-def run(model, input_tensors, params, precision, is_trt_engine=False):
+def run(model, input_tensors, params, precision, truncate_long_and_double=False, batch_size=1, is_trt_engine=False):
     for backend in params.get('backend'):

         if precision == 'int8':
@@ -200,18 +199,19 @@ def run(model, input_tensors, params, precision, is_trt_engine=False):
             return False

         if backend == 'all':
-            run_torch(model, input_tensors, params, precision)
-            run_torch_tensorrt(model, input_tensors, params, precision)
-            run_tensorrt(model, input_tensors, params, precision, is_trt_engine)
-
+            run_torch(model, input_tensors, params, precision, batch_size)
+            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
+            # import pdb; pdb.set_trace()
+            run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)
+
         elif backend == "torch":
-            run_torch(model, input_tensors, params, precision)
-
+            run_torch(model, input_tensors, params, precision, batch_size)
+
         elif backend == "torch_tensorrt":
-            run_torch_tensorrt(model, input_tensors, params, precision)
-
+            run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
+
         elif backend == "tensorrt":
-            run_tensorrt(model, input_tensors, params, precision, is_trt_engine)
+            run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)

 # Generate report
 def printStats(backend, timings, precision, batch_size=1):
@@ -274,7 +274,7 @@ def load_model(params):
     arg_parser = argparse.ArgumentParser(description="Run inference on a model with random input values")
     arg_parser.add_argument("--config", help="Load YAML based configuration file to run the inference. If this is used other params will be ignored")
     args = arg_parser.parse_args()
-
+
     parser = ConfigParser(args.config)
     # Load YAML params
     params = parser.read_config()
@@ -293,6 +293,8 @@ def load_model(params):
     torch.manual_seed(12345)

     num_input = params.get('input').get('num_inputs')
+    truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
+    batch_size = params.get('input').get('batch_size', 1)
     for precision in params.get('runtime').get('precision', 'fp32'):
         input_tensors = []
         num_input = params.get('input').get('num_inputs', 1)
@@ -306,9 +308,9 @@ def load_model(params):
         if not is_trt_engine and precision == "fp16" or precision == "half":
             # If model is TensorRT serialized engine then model.half will report failure
             model = model.half()
-
+
         # Run inference
-        status = run(model, input_tensors, params, precision, is_trt_engine)
+        status = run(model, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
         if status == False:
             continue

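Below is a minimal sketch of a YAML configuration exercising the options this patch adds, assuming the layout implied by the lookups above (params.get('runtime').get('truncate_long_and_double', False) and params.get('input').get('batch_size', 1)); every key and value is an illustrative placeholder reconstructed from those calls, not taken from the repository:

    # Hypothetical config sketch; only keys the script visibly reads are shown.
    backend:
      - all                          # one of: all, torch, torch_tensorrt, tensorrt
    iterations: 20                   # read via params.get('iterations', 20)
    input:
      num_inputs: 1
      batch_size: 8                  # new: forwarded to printStats for per-batch reporting
    runtime:
      precision:
        - fp32
        - fp16
      truncate_long_and_double: true # new: forwarded to the Torch-TensorRT compile settings

Note the design choice visible in run_tensorrt: buffers are now sized straight from engine.get_binding_shape(idx) and execute_async is always called with a batch argument of 1, which matches explicit-batch TensorRT engines where the batch dimension is already part of each binding's shape; batch_size survives only so printStats can report per-batch throughput.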