Skip to content

Commit ec06d6f

Browse files
committed
chore: linter fixes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5923e6b commit ec06d6f

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

tools/perf/perf_run.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Importing supported Backends
1818
import torch
1919
import torch_tensorrt as torchtrt
20+
2021
# from torch_tensorrt.fx.lower import compile
2122
# from torch_tensorrt.fx.utils import LowerPrecision
2223

@@ -144,7 +145,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
144145
model,
145146
ir="fx",
146147
inputs=input_tensors,
147-
enabled_precisions={torch.float16 if precision=="fp16" else torch.float32},
148+
enabled_precisions={torch.float16 if precision == "fp16" else torch.float32},
148149
)
149150
end_compile = time.time_ns()
150151
compile_time_ms = (end_compile - start_compile) / 1e6
@@ -169,12 +170,20 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
169170

170171
recordStats("FX-TensorRT", timings, precision, batch_size, compile_time_ms)
171172

173+
172174
def run_dynamo(model, input_tensors, params, precision, batch_size):
173175
dynamo_backend = params["dynamo_backend"]
174-
print("Running Dynamo with backend: ", dynamo_backend, " for precision: ", precision, " batch_size : ", batch_size)
176+
print(
177+
"Running Dynamo with backend: ",
178+
dynamo_backend,
179+
" for precision: ",
180+
precision,
181+
" batch_size : ",
182+
batch_size,
183+
)
175184

176185
if precision == "fp16":
177-
input_tensors = [tensor.half() for tensor in input_tensors]
186+
input_tensors = [tensor.half() for tensor in input_tensors]
178187

179188
fp16_mode = True if precision == "fp16" else False
180189
# dynamo_backend_params = {"fp16_mode" : fp16_mode}
@@ -187,6 +196,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
187196
# # **dynamo_backend_params
188197
# )
189198
import torch._dynamo as dynamo
199+
190200
model = dynamo.optimize(dynamo_backend, nopython=True)(model)
191201
# Compile and measure the time
192202
with torch.no_grad():
@@ -219,7 +229,10 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
219229
meas_time = end_time - start_time
220230
timings.append(meas_time)
221231

222-
recordStats("Dynamo-" + dynamo_backend, timings, precision, batch_size, compile_time_ms)
232+
recordStats(
233+
"Dynamo-" + dynamo_backend, timings, precision, batch_size, compile_time_ms
234+
)
235+
223236

224237
def torch_dtype_from_trt(dtype):
225238
if dtype == trt.int8:

0 commit comments

Comments
 (0)