Skip to content

Commit 2f3ada3

Browse files
committed
chore: refactor code
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0363526 commit 2f3ada3

File tree

2 files changed

+39
-33
lines changed

2 files changed

+39
-33
lines changed

py/torch_tensorrt/fx/input_tensor_spec.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,43 @@ def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"
117117
assert isinstance(tensors, (list, tuple))
118118
return [cls.from_tensor(t) for t in tensors]
119119

120+
@classmethod
121+
def from_input(cls, input_obj: Input) -> "InputTensorSpec":
122+
"""
123+
Produce a list of InputTenosrSpec named tuples which contain
124+
the information of all the given PyTorch tensors.
125+
126+
Args:
127+
tensors (Iterable[torch.Tensor]): A list of PyTorch tensors.
128+
129+
Returns:
130+
A list of InputTensorSpec named tuples.
131+
"""
132+
assert isinstance(input_obj, Input)
133+
input_spec = None
134+
if isinstance(input_obj.shape, dict):
135+
min_shape = input_obj.shape["min_shape"]
136+
opt_shape = input_obj.shape["opt_shape"]
137+
max_shape = input_obj.shape["max_shape"]
138+
dyn_shape = []
139+
for min, opt, max in zip(min_shape, opt_shape, max_shape):
140+
if min == opt == max:
141+
dyn_shape.append(min)
142+
else:
143+
dyn_shape.append(-1)
144+
dtype = input_obj.torch_dtype
145+
input_spec = cls(
146+
shape=dyn_shape,
147+
dtype=dtype,
148+
shape_ranges=[(min_shape, opt_shape, max_shape)],
149+
)
150+
else:
151+
shape = input_obj.shape
152+
dtype = input_obj.torch_dtype
153+
input_spec = cls(shape=shape, dtype=dtype)
154+
155+
return input_spec
156+
120157
@classmethod
121158
def from_tensors_with_dynamic_batch_size(
122159
cls,

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -61,37 +61,6 @@
6161
# ----------------------------------------------------------------------
6262

6363

64-
def convert_input_to_spec(input_obj: Any):
65-
input_spec = None
66-
if isinstance(input_obj, _Input.Input):
67-
if isinstance(input_obj.shape, dict):
68-
min_shape = input_obj.shape["min_shape"]
69-
opt_shape = input_obj.shape["opt_shape"]
70-
max_shape = input_obj.shape["max_shape"]
71-
dyn_shape = []
72-
for min, opt, max in zip(min_shape, opt_shape, max_shape):
73-
if min == opt == max:
74-
dyn_shape.append(min)
75-
else:
76-
dyn_shape.append(-1)
77-
dtype = input_obj.torch_dtype
78-
input_spec = InputTensorSpec(
79-
shape=dyn_shape,
80-
dtype=dtype,
81-
shape_ranges=[(min_shape, opt_shape, max_shape)],
82-
)
83-
else:
84-
shape = input_obj.shape
85-
dtype = input_obj.torch_dtype
86-
input_spec = InputTensorSpec(shape=shape, dtype=dtype)
87-
else:
88-
raise ValueError(
89-
"Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor"
90-
)
91-
92-
return input_spec
93-
94-
9564
def wrapper(fn: Callable, input) -> Callable:
9665
@wraps(fn)
9766
def wrapped_fn(gm):
@@ -303,7 +272,7 @@ def build_trt_lower_pipeline(
303272
self._trt_input = []
304273
for input_obj in input:
305274
if isinstance(input_obj, _Input.Input):
306-
self._trt_input.append(convert_input_to_spec(input_obj))
275+
self._trt_input.append(InputTensorSpec.from_input(input_obj))
307276

308277
self._additional_input = additional_input
309278
passes = []
@@ -325,7 +294,7 @@ def build_aten2trt_lower_pipeline(
325294
self._trt_input = []
326295
for input_obj in input:
327296
if isinstance(input_obj, _Input.Input):
328-
self._trt_input.append(convert_input_to_spec(input_obj))
297+
self._trt_input.append(InputTensorSpec.from_input(input_obj))
329298

330299
self._additional_input = additional_input
331300
passes = []

0 commit comments

Comments
 (0)