Skip to content

Commit 46961d9

Browse files
committed
chore: additional options for perf_run tool
Signed-off-by: dperi <[email protected]>
1 parent 160fe4f commit 46961d9

File tree

2 files changed: 46 additions & 43 deletions

tools/perf/config/vgg16.yml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
---
2-
backend:
1+
---
2+
backend:
33
- all
4-
input:
5-
input0:
4+
input:
5+
input0:
66
- 1
77
- 3
88
- 224
99
- 224
1010
num_inputs: 1
11-
model:
12-
filename: models/vgg16_traced.jit.pt
11+
batch_size: 1
12+
model:
13+
filename: models/vgg16_scripted.jit.pt
1314
name: vgg16
14-
runtime:
15+
runtime:
1516
device: 0
16-
precision:
17+
precision:
1718
- fp32
1819
- fp16

tools/perf/perf_run.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, config_file):
2626
self.parser = None
2727
self.config = config_file
2828
self.params = None
29-
29+
3030
# Reads and loads the yaml file
3131
def read_config(self):
3232
with open(self.config, "r") as stream:
@@ -45,7 +45,7 @@ def get(self, key, default_value=None):
4545
return self.params[key]
4646

4747
# Runs inference using Torch backend
48-
def run_torch(model, input_tensors, params, precision):
48+
def run_torch(model, input_tensors, params, precision, batch_size):
4949
print("Running Torch for precision: ", precision)
5050
iters = params.get('iterations', 20)
5151

@@ -66,25 +66,25 @@ def run_torch(model, input_tensors, params, precision):
6666
meas_time = end_time - start_time
6767
timings.append(meas_time)
6868
print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
69-
70-
printStats("Torch", timings, precision)
69+
70+
printStats("Torch", timings, precision, batch_size)
7171

7272
# Runs inference using Torch-TensorRT backend
73-
def run_torch_tensorrt(model, input_tensors, params, precision):
73+
def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size):
7474
print("Running Torch-TensorRT")
75-
7675
# Compiling Torch-TensorRT model
7776
compile_settings = {
7877
"inputs": input_tensors,
79-
"enabled_precisions": {precision_to_dtype(precision)}
78+
"enabled_precisions": {precision_to_dtype(precision)} ,
79+
"truncate_long_and_double": truncate_long_and_double,
8080
}
8181

8282
if precision == 'int8':
8383
compile_settings.update({"calib": params.get('calibration_cache')})
8484

85-
85+
8686
model = torchtrt.compile(model, **compile_settings)
87-
87+
8888
iters = params.get('iterations', 20)
8989
# Warm up
9090
with torch.no_grad():
@@ -103,8 +103,8 @@ def run_torch_tensorrt(model, input_tensors, params, precision):
103103
meas_time = end_time - start_time
104104
timings.append(meas_time)
105105
print("Iteration {}: {:.6f} s".format(i, end_time - start_time))
106-
107-
printStats("Torch-TensorRT", timings, precision)
106+
107+
printStats("Torch-TensorRT", timings, precision, batch_size)
108108

109109
def torch_dtype_from_trt(dtype):
110110
if dtype == trt.int8:
@@ -129,7 +129,7 @@ def torch_device_from_trt(device):
129129
return TypeError("%s is not supported by torch" % device)
130130

131131

132-
def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
132+
def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False, batch_size=1):
133133
engine = None
134134

135135
# If the model file is a TensorRT engine then directly deserialize and run inference
@@ -143,22 +143,21 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
143143
print("Converting method to TensorRT engine...")
144144
with torch.no_grad():
145145
model = torchtrt.ts.convert_method_to_trt_engine(model, "forward", **compile_settings)
146-
146+
147147
# Deserialize the TensorRT engine
148148
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
149149
engine = runtime.deserialize_cuda_engine(model)
150-
150+
151151
print("Running TensorRT")
152152
iters = params.get('iterations', 20)
153-
batch_size = params.get('batch', 1)
154153

155154
# Compiling the bindings
156155
bindings = engine.num_bindings * [None]
157-
156+
# import pdb; pdb.set_trace()
158157
k = 0
159158
for idx,_ in enumerate(bindings):
160159
dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
161-
shape = (batch_size,) + tuple(engine.get_binding_shape(idx))
160+
shape = tuple(engine.get_binding_shape(idx))
162161
device = torch_device_from_trt(engine.get_location(idx))
163162
if not engine.binding_is_input(idx):
164163
# Output bindings
@@ -168,26 +167,26 @@ def run_tensorrt(model, input_tensors, params, precision, is_trt_engine=False):
168167
# Input bindings
169168
bindings[idx] = input_tensors[k].data_ptr()
170169
k += 1
171-
170+
172171
timings = []
173172
with engine.create_execution_context() as context:
174173
for i in range(WARMUP_ITER):
175-
context.execute_async(batch_size, bindings, torch.cuda.current_stream().cuda_stream)
174+
context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
176175
torch.cuda.synchronize()
177176

178177
for i in range(iters):
179178
start_time = timeit.default_timer()
180-
context.execute_async(batch_size, bindings, torch.cuda.current_stream().cuda_stream)
179+
context.execute_async(1, bindings, torch.cuda.current_stream().cuda_stream)
181180
torch.cuda.synchronize()
182181
end_time = timeit.default_timer()
183182
meas_time = end_time - start_time
184183
timings.append(meas_time)
185184
print("Iterations {}: {:.6f} s".format(i, end_time - start_time))
186-
187-
printStats("TensorRT", timings, precision)
185+
186+
printStats("TensorRT", timings, precision, batch_size)
188187

189188
# Deploys inference run for different backend configurations
190-
def run(model, input_tensors, params, precision, is_trt_engine = False):
189+
def run(model, input_tensors, params, precision, truncate_long_and_double = False, batch_size = 1, is_trt_engine = False):
191190
for backend in params.get('backend'):
192191

193192
if precision == 'int8':
@@ -200,18 +199,19 @@ def run(model, input_tensors, params, precision, is_trt_engine = False):
200199
return False
201200

202201
if backend == 'all':
203-
run_torch(model, input_tensors, params, precision)
204-
run_torch_tensorrt(model, input_tensors, params, precision)
205-
run_tensorrt(model, input_tensors, params, precision, is_trt_engine)
206-
202+
run_torch(model, input_tensors, params, precision, batch_size)
203+
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
204+
# import pdb; pdb.set_trace()
205+
run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)
206+
207207
elif backend == "torch":
208-
run_torch(model, input_tensors, params, precision)
209-
208+
run_torch(model, input_tensors, params, precision, batch_size)
209+
210210
elif backend == "torch_tensorrt":
211-
run_torch_tensorrt(model, input_tensors, params, precision)
212-
211+
run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_and_double, batch_size)
212+
213213
elif backend == "tensorrt":
214-
run_tensorrt(model, input_tensors, params, precision, is_trt_engine)
214+
run_tensorrt(model, input_tensors, params, precision, is_trt_engine, batch_size)
215215

216216
# Generate report
217217
def printStats(backend, timings, precision, batch_size = 1):
@@ -274,7 +274,7 @@ def load_model(params):
274274
arg_parser = argparse.ArgumentParser(description="Run inference on a model with random input values")
275275
arg_parser.add_argument("--config", help="Load YAML based configuration file to run the inference. If this is used other params will be ignored")
276276
args = arg_parser.parse_args()
277-
277+
278278
parser = ConfigParser(args.config)
279279
# Load YAML params
280280
params = parser.read_config()
@@ -293,6 +293,8 @@ def load_model(params):
293293
torch.manual_seed(12345)
294294

295295
num_input = params.get('input').get('num_inputs')
296+
truncate_long_and_double = params.get('runtime').get('truncate_long_and_double', False)
297+
batch_size = params.get('input').get('batch_size', 1)
296298
for precision in params.get('runtime').get('precision', 'fp32'):
297299
input_tensors = []
298300
num_input = params.get('input').get('num_inputs', 1)
@@ -306,9 +308,9 @@ def load_model(params):
306308
if not is_trt_engine and precision == "fp16" or precision == "half":
307309
# If model is TensorRT serialized engine then model.half will report failure
308310
model = model.half()
309-
311+
310312
# Run inference
311-
status = run(model, input_tensors, params, precision, is_trt_engine)
313+
status = run(model, input_tensors, params, precision, truncate_long_and_double, batch_size, is_trt_engine)
312314
if status == False:
313315
continue
314316

0 commit comments

Comments
 (0)