12
12
from torch .fx .passes import shape_prop
13
13
from torch .fx .passes .infra .pass_base import PassResult
14
14
from torch .testing ._internal .common_utils import TestCase
15
- from torch_tensorrt .fx import TRTInterpreter , TRTModule
16
- from torch_tensorrt ._Input import Input
15
+ from torch_tensorrt .fx import InputTensorSpec , TRTInterpreter , TRTModule
17
16
from torch_tensorrt .fx .passes .lower_basic_pass_aten import (
18
17
compose_bmm ,
19
18
compose_chunk ,
@@ -212,7 +211,7 @@ def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03):
212
211
mod = torch .fx .symbolic_trace (mod )
213
212
shape_prop .ShapeProp (mod ).propagate (* inputs )
214
213
mod = NormalizeArgs (mod ).transform ()
215
- interp = TRTInterpreter (mod , Input .from_tensors (inputs ))
214
+ interp = TRTInterpreter (mod , InputTensorSpec .from_tensors (inputs ))
216
215
super ().run_test (mod , inputs , expected_ops , None , interp , rtol , atol )
217
216
218
217
def run_test_custom_compare_results (
@@ -230,7 +229,7 @@ def run_test_custom_compare_results(
230
229
mod = torch .fx .symbolic_trace (mod )
231
230
shape_prop .ShapeProp (mod ).propagate (* inputs )
232
231
mod = NormalizeArgs (mod ).transform ()
233
- interp = TRTInterpreter (mod , Input .from_tensors (inputs ))
232
+ interp = TRTInterpreter (mod , InputTensorSpec .from_tensors (inputs ))
234
233
super ().run_test_custom_compare_results (
235
234
mod , inputs , expected_ops , interp , comparators , fp16_mode = fp16_mode
236
235
)
@@ -259,14 +258,14 @@ def run_test(
259
258
mod = pass_tracer (mod , inputs )
260
259
261
260
if test_implicit_batch_dim :
262
- interp = TRTInterpreter (mod , Input .from_tensors (inputs ))
261
+ interp = TRTInterpreter (mod , InputTensorSpec .from_tensors (inputs ))
263
262
super ().run_test (
264
263
mod , inputs , expected_ops , unexpected_ops , interp , rtol , atol , precision
265
264
)
266
265
267
266
if test_explicit_batch_dim :
268
267
interp = TRTInterpreter (
269
- mod , Input .from_tensors (inputs ), explicit_batch_dimension = True
268
+ mod , InputTensorSpec .from_tensors (inputs ), explicit_batch_dimension = True
270
269
)
271
270
super ().run_test (
272
271
mod , inputs , expected_ops , unexpected_ops , interp , rtol , atol , precision
@@ -275,7 +274,7 @@ def run_test(
275
274
if test_explicit_precision :
276
275
interp = TRTInterpreter (
277
276
mod ,
278
- Input .from_tensors (inputs ),
277
+ InputTensorSpec .from_tensors (inputs ),
279
278
explicit_precision = test_explicit_precision ,
280
279
)
281
280
super ().run_test (
@@ -284,7 +283,7 @@ def run_test(
284
283
285
284
interp = TRTInterpreter (
286
285
mod ,
287
- Input .from_tensors (inputs ),
286
+ InputTensorSpec .from_tensors (inputs ),
288
287
explicit_batch_dimension = True ,
289
288
explicit_precision = test_explicit_precision ,
290
289
)
@@ -304,12 +303,12 @@ def run_test_with_assert_error(
304
303
mod = acc_tracer .trace (mod , inputs )
305
304
306
305
if test_implicit_batch_dim :
307
- interp = TRTInterpreter (mod , Input .from_tensors (inputs ))
306
+ interp = TRTInterpreter (mod , InputTensorSpec .from_tensors (inputs ))
308
307
super ().run_test_with_error (mod , inputs , interp , expect_error )
309
308
310
309
if test_explicit_batch_dim :
311
310
interp = TRTInterpreter (
312
- mod , Input .from_tensors (inputs ), explicit_batch_dimension = True
311
+ mod , InputTensorSpec .from_tensors (inputs ), explicit_batch_dimension = True
313
312
)
314
313
super ().run_test_with_error (mod , inputs , interp , expect_error )
315
314
@@ -323,7 +322,7 @@ def run_test_with_dynamic_shape(
323
322
atol = 1e-03 ,
324
323
):
325
324
mod .eval ()
326
- inputs = Input .create_inputs_from_specs (input_specs )
325
+ inputs = InputTensorSpec .create_inputs_from_specs (input_specs )
327
326
mod = acc_tracer .trace (mod , inputs )
328
327
interp = TRTInterpreter (mod , input_specs , explicit_batch_dimension = True )
329
328
super ().run_test (mod , inputs , expected_ops , unexpected_ops , interp , rtol , atol )
@@ -393,7 +392,7 @@ def run_test(
393
392
if test_explicit_batch_dim :
394
393
interp = TRTInterpreter (
395
394
mod ,
396
- Input .from_tensors (inputs ),
395
+ InputTensorSpec .from_tensors (inputs ),
397
396
explicit_batch_dimension = True ,
398
397
)
399
398
super ().run_test (
@@ -403,7 +402,7 @@ def run_test(
403
402
if test_explicit_precision :
404
403
interp = TRTInterpreter (
405
404
mod ,
406
- Input .from_tensors (inputs ),
405
+ InputTensorSpec .from_tensors (inputs ),
407
406
explicit_precision = test_explicit_precision ,
408
407
)
409
408
super ().run_test (
@@ -412,7 +411,7 @@ def run_test(
412
411
413
412
interp = TRTInterpreter (
414
413
mod ,
415
- Input .from_tensors (inputs ),
414
+ InputTensorSpec .from_tensors (inputs ),
416
415
explicit_batch_dimension = True ,
417
416
explicit_precision = test_explicit_precision ,
418
417
)
@@ -430,7 +429,7 @@ def run_test_with_dynamic_shape(
430
429
atol = 1e-03 ,
431
430
):
432
431
mod .eval ()
433
- inputs = Input .create_inputs_from_specs (input_specs )
432
+ inputs = InputTensorSpec .create_inputs_from_specs (input_specs )
434
433
mod = self .generate_graph (mod , inputs , expected_ops , unexpected_ops , None )
435
434
436
435
interp = TRTInterpreter (
@@ -440,7 +439,7 @@ def run_test_with_dynamic_shape(
440
439
)
441
440
# Since the lowering is based on optimal shape. We need to test with
442
441
# different shape(for ex. max shape) for testing dynamic shape
443
- inputs_max = Input .create_inputs_from_max_specs (input_specs )
442
+ inputs_max = InputTensorSpec .create_inputs_from_max_specs (input_specs )
444
443
super ().run_test (
445
444
mod , inputs_max , expected_ops , unexpected_ops , interp , rtol , atol
446
445
)
0 commit comments