17
17
# Importing supported Backends
18
18
import torch
19
19
import torch_tensorrt as torchtrt
20
+
20
21
# from torch_tensorrt.fx.lower import compile
21
22
# from torch_tensorrt.fx.utils import LowerPrecision
22
23
@@ -144,7 +145,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
144
145
model ,
145
146
ir = "fx" ,
146
147
inputs = input_tensors ,
147
- enabled_precisions = {torch .float16 if precision == "fp16" else torch .float32 },
148
+ enabled_precisions = {torch .float16 if precision == "fp16" else torch .float32 },
148
149
)
149
150
end_compile = time .time_ns ()
150
151
compile_time_ms = (end_compile - start_compile ) / 1e6
@@ -169,12 +170,20 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
169
170
170
171
recordStats ("FX-TensorRT" , timings , precision , batch_size , compile_time_ms )
171
172
173
+
172
174
def run_dynamo (model , input_tensors , params , precision , batch_size ):
173
175
dynamo_backend = params ["dynamo_backend" ]
174
- print ("Running Dynamo with backend: " , dynamo_backend , " for precision: " , precision , " batch_size : " , batch_size )
176
+ print (
177
+ "Running Dynamo with backend: " ,
178
+ dynamo_backend ,
179
+ " for precision: " ,
180
+ precision ,
181
+ " batch_size : " ,
182
+ batch_size ,
183
+ )
175
184
176
185
if precision == "fp16" :
177
- input_tensors = [tensor .half () for tensor in input_tensors ]
186
+ input_tensors = [tensor .half () for tensor in input_tensors ]
178
187
179
188
fp16_mode = True if precision == "fp16" else False
180
189
# dynamo_backend_params = {"fp16_mode" : fp16_mode}
@@ -187,6 +196,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
187
196
# # **dynamo_backend_params
188
197
# )
189
198
import torch ._dynamo as dynamo
199
+
190
200
model = dynamo .optimize (dynamo_backend , nopython = True )(model )
191
201
# Compile and measure the time
192
202
with torch .no_grad ():
@@ -219,7 +229,10 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
219
229
meas_time = end_time - start_time
220
230
timings .append (meas_time )
221
231
222
- recordStats ("Dynamo-" + dynamo_backend , timings , precision , batch_size , compile_time_ms )
232
+ recordStats (
233
+ "Dynamo-" + dynamo_backend , timings , precision , batch_size , compile_time_ms
234
+ )
235
+
223
236
224
237
def torch_dtype_from_trt (dtype ):
225
238
if dtype == trt .int8 :
0 commit comments