Skip to content

Commit 987b2b3

Browse files
committed
chore: revert FX changes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 1b60341 commit 987b2b3

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

py/torch_tensorrt/fx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
     tensorrt_converter,
 )
 from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
-from .input_tensor_spec import generate_input_specs  # noqa
+from .input_tensor_spec import generate_input_specs, InputTensorSpec  # noqa
 from .lower_setting import LowerSetting  # noqa
 from .trt_module import TRTModule  # noqa
 from .lower import compile  # usort: skip #noqa

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
 from torch.fx.passes.shape_prop import TensorMetadata

 from .converter_registry import CONVERTERS
-from .._Input import Input
+from .input_tensor_spec import InputTensorSpec
 from .observer import Observer
 from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt

@@ -36,7 +36,7 @@ class TRTInterpreter(torch.fx.Interpreter):
     def __init__(
         self,
         module: torch.fx.GraphModule,
-        input_specs: List[Input],
+        input_specs: List[InputTensorSpec],
         explicit_batch_dimension: bool = False,
         explicit_precision: bool = False,
         logger_level=None,
@@ -79,7 +79,6 @@ def __init__(
         ] = dict()

     def validate_input_specs(self):
-        # import pdb; pdb.set_trace()
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
                 assert (

py/torch_tensorrt/fx/tools/common_fx2trt.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
 from torch.fx.passes import shape_prop
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import TestCase
-from torch_tensorrt.fx import TRTInterpreter, TRTModule
-from torch_tensorrt._Input import Input
+from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -212,7 +211,7 @@ def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03):
         mod = torch.fx.symbolic_trace(mod)
         shape_prop.ShapeProp(mod).propagate(*inputs)
         mod = NormalizeArgs(mod).transform()
-        interp = TRTInterpreter(mod, Input.from_tensors(inputs))
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
         super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)

     def run_test_custom_compare_results(
@@ -230,7 +229,7 @@ def run_test_custom_compare_results(
         mod = torch.fx.symbolic_trace(mod)
         shape_prop.ShapeProp(mod).propagate(*inputs)
         mod = NormalizeArgs(mod).transform()
-        interp = TRTInterpreter(mod, Input.from_tensors(inputs))
+        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
         super().run_test_custom_compare_results(
             mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
         )
@@ -259,14 +258,14 @@ def run_test(
             mod = pass_tracer(mod, inputs)

         if test_implicit_batch_dim:
-            interp = TRTInterpreter(mod, Input.from_tensors(inputs))
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
             super().run_test(
                 mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
             )

         if test_explicit_batch_dim:
             interp = TRTInterpreter(
-                mod, Input.from_tensors(inputs), explicit_batch_dimension=True
+                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
             )
             super().run_test(
                 mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
@@ -275,7 +274,7 @@ def run_test(
         if test_explicit_precision:
             interp = TRTInterpreter(
                 mod,
-                Input.from_tensors(inputs),
+                InputTensorSpec.from_tensors(inputs),
                 explicit_precision=test_explicit_precision,
             )
             super().run_test(
@@ -284,7 +283,7 @@ def run_test(

             interp = TRTInterpreter(
                 mod,
-                Input.from_tensors(inputs),
+                InputTensorSpec.from_tensors(inputs),
                 explicit_batch_dimension=True,
                 explicit_precision=test_explicit_precision,
             )
@@ -304,12 +303,12 @@ def run_test_with_assert_error(
         mod = acc_tracer.trace(mod, inputs)

         if test_implicit_batch_dim:
-            interp = TRTInterpreter(mod, Input.from_tensors(inputs))
+            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
             super().run_test_with_error(mod, inputs, interp, expect_error)

         if test_explicit_batch_dim:
             interp = TRTInterpreter(
-                mod, Input.from_tensors(inputs), explicit_batch_dimension=True
+                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
             )
             super().run_test_with_error(mod, inputs, interp, expect_error)

@@ -323,7 +322,7 @@ def run_test_with_dynamic_shape(
         atol=1e-03,
     ):
         mod.eval()
-        inputs = Input.create_inputs_from_specs(input_specs)
+        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
         mod = acc_tracer.trace(mod, inputs)
         interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
         super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
@@ -393,7 +392,7 @@ def run_test(
         if test_explicit_batch_dim:
             interp = TRTInterpreter(
                 mod,
-                Input.from_tensors(inputs),
+                InputTensorSpec.from_tensors(inputs),
                 explicit_batch_dimension=True,
             )
             super().run_test(
@@ -403,7 +402,7 @@ def run_test(
         if test_explicit_precision:
             interp = TRTInterpreter(
                 mod,
-                Input.from_tensors(inputs),
+                InputTensorSpec.from_tensors(inputs),
                 explicit_precision=test_explicit_precision,
             )
             super().run_test(
@@ -412,7 +411,7 @@ def run_test(

             interp = TRTInterpreter(
                 mod,
-                Input.from_tensors(inputs),
+                InputTensorSpec.from_tensors(inputs),
                 explicit_batch_dimension=True,
                 explicit_precision=test_explicit_precision,
             )
@@ -430,7 +429,7 @@ def run_test_with_dynamic_shape(
         atol=1e-03,
     ):
         mod.eval()
-        inputs = Input.create_inputs_from_specs(input_specs)
+        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
         mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)

         interp = TRTInterpreter(
@@ -440,7 +439,7 @@ def run_test_with_dynamic_shape(
         )
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
-        inputs_max = Input.create_inputs_from_max_specs(input_specs)
+        inputs_max = InputTensorSpec.create_inputs_from_max_specs(input_specs)
         super().run_test(
             mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
         )

0 commit comments

Comments
 (0)