Skip to content

Commit 1eeb319

Browse files
committed
chore(//py/torch_tensorrt/ts): Make torch_tensorrt.ts.TorchScriptInput
mypy compliant Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6fd5d6a commit 1eeb319

File tree

5 files changed

+54
-53
lines changed

5 files changed

+54
-53
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import List, Dict, Any, Tuple, Optional
2+
from typing import List, Dict, Any, Tuple, Optional, Union
33

44
import torch
55

@@ -28,7 +28,7 @@ class _ShapeMode(Enum):
2828
DYNAMIC = 1
2929

3030
shape_mode = None #: (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
31-
shape = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
31+
shape: Union[Tuple, Dict] = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
3232
dtype = (
3333
_enums.dtype.unknown
3434
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
@@ -42,7 +42,7 @@ class _ShapeMode(Enum):
4242
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
4343
torch_dtype = torch.float32
4444

45-
def __init__(self, *args, **kwargs):
45+
def __init__(self, *args: Any, **kwargs: Any) -> None:
4646
"""__init__ Method for torch_tensorrt.Input
4747
4848
Input accepts one of a few construction patterns

py/torch_tensorrt/ts/ts_input.py renamed to py/torch_tensorrt/ts/_Input.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch_tensorrt._Input import Input
1010

1111

12-
class TSInput(Input):
12+
class TorchScriptInput(Input):
1313
"""
1414
Defines an input to a module in terms of expected shape, data type and tensor format.
1515
@@ -26,7 +26,7 @@ class TSInput(Input):
2626
format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
2727
"""
2828

29-
def __init__(self, *args, **kwargs):
29+
def __init__(self, *args: Any, **kwargs: Any) -> None:
3030
"""__init__ Method for torch_tensorrt.Input
3131
3232
Input accepts one of a few construction patterns
@@ -52,38 +52,39 @@ def __init__(self, *args, **kwargs):
5252
- Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
5353
- Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
5454
"""
55-
super(TSInput, self).__init__(*args, **kwargs)
55+
super().__init__(*args, **kwargs)
5656

5757
def _to_internal(self) -> _C.Input:
5858
internal_in = _C.Input()
5959
if self.shape_mode == Input._ShapeMode.DYNAMIC:
60-
if not Input._supported_input_size_type(self.shape["min_shape"]):
61-
raise TypeError(
62-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
63-
+ str(type(self.shape["min_shape"]))
64-
+ " for min_shape"
65-
)
66-
else:
67-
internal_in.min = self.shape["min_shape"]
68-
69-
if not Input._supported_input_size_type(self.shape["opt_shape"]):
70-
raise TypeError(
71-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
72-
+ str(type(self.shape["opt_shape"]))
73-
+ " for opt_shape"
74-
)
75-
else:
76-
internal_in.opt = self.shape["opt_shape"]
77-
78-
if not Input._supported_input_size_type(self.shape["max_shape"]):
79-
raise TypeError(
80-
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
81-
+ str(type(self.shape["max_shape"]))
82-
+ " for max_shape"
83-
)
84-
else:
85-
internal_in.max = self.shape["max_shape"]
86-
internal_in.input_is_dynamic = True
60+
if isinstance(self.shape, dict):
61+
if not Input._supported_input_size_type(self.shape["min_shape"]):
62+
raise TypeError(
63+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
64+
+ str(type(self.shape["min_shape"]))
65+
+ " for min_shape"
66+
)
67+
else:
68+
internal_in.min = self.shape["min_shape"]
69+
70+
if not Input._supported_input_size_type(self.shape["opt_shape"]):
71+
raise TypeError(
72+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
73+
+ str(type(self.shape["opt_shape"]))
74+
+ " for opt_shape"
75+
)
76+
else:
77+
internal_in.opt = self.shape["opt_shape"]
78+
79+
if not Input._supported_input_size_type(self.shape["max_shape"]):
80+
raise TypeError(
81+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
82+
+ str(type(self.shape["max_shape"]))
83+
+ " for max_shape"
84+
)
85+
else:
86+
internal_in.max = self.shape["max_shape"]
87+
internal_in.input_is_dynamic = True
8788
else:
8889
if not Input._supported_input_size_type(self.shape):
8990
raise TypeError(

py/torch_tensorrt/ts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from torch_tensorrt.ts._compiler import *
22
from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec
3-
from torch_tensorrt.ts.ts_input import TSInput
3+
from torch_tensorrt.ts._Input import TorchScriptInput

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Tuple, List, Dict
1010
import warnings
1111
from copy import deepcopy
12-
from torch_tensorrt.ts.ts_input import TSInput
12+
from torch_tensorrt.ts._Input import TorchScriptInput
1313
import tensorrt as trt
1414

1515

@@ -195,9 +195,9 @@ def _parse_input_signature(input_signature: Any, depth: int = 0):
195195

196196
ts_i = i
197197
if i.shape_mode == Input._ShapeMode.STATIC:
198-
ts_i = TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
198+
ts_i = TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
199199
elif i.shape_mode == Input._ShapeMode.DYNAMIC:
200-
ts_i = TSInput(
200+
ts_i = TorchScriptInput(
201201
min_shape=i.shape["min_shape"],
202202
opt_shape=i.shape["opt_shape"],
203203
max_shape=i.shape["max_shape"],
@@ -245,13 +245,13 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
245245
for i in inputs:
246246
if i.shape_mode == Input._ShapeMode.STATIC:
247247
ts_inputs.append(
248-
TSInput(
248+
TorchScriptInput(
249249
shape=i.shape, dtype=i.dtype, format=i.format
250250
)._to_internal()
251251
)
252252
elif i.shape_mode == Input._ShapeMode.DYNAMIC:
253253
ts_inputs.append(
254-
TSInput(
254+
TorchScriptInput(
255255
min_shape=i.shape["min_shape"],
256256
opt_shape=i.shape["opt_shape"],
257257
max_shape=i.shape["max_shape"],

tests/py/ts/api/test_classes.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_infer_from_example_tensor(self):
104104

105105
example_tensor = torch.randn(shape).half()
106106
i = torchtrt.Input.from_tensor(example_tensor)
107-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
107+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
108108
self.assertTrue(self._verify_correctness(ts_i, target))
109109

110110
def test_static_shape(self):
@@ -120,27 +120,27 @@ def test_static_shape(self):
120120
}
121121

122122
i = torchtrt.Input(shape)
123-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
123+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
124124
self.assertTrue(self._verify_correctness(ts_i, target))
125125

126126
i = torchtrt.Input(tuple(shape))
127-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
127+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
128128
self.assertTrue(self._verify_correctness(ts_i, target))
129129

130130
i = torchtrt.Input(torch.randn(shape).shape)
131-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
131+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
132132
self.assertTrue(self._verify_correctness(ts_i, target))
133133

134134
i = torchtrt.Input(shape=shape)
135-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
135+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
136136
self.assertTrue(self._verify_correctness(ts_i, target))
137137

138138
i = torchtrt.Input(shape=tuple(shape))
139-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
139+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
140140
self.assertTrue(self._verify_correctness(ts_i, target))
141141

142142
i = torchtrt.Input(shape=torch.randn(shape).shape)
143-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
143+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
144144
self.assertTrue(self._verify_correctness(ts_i, target))
145145

146146
def test_data_type(self):
@@ -156,11 +156,11 @@ def test_data_type(self):
156156
}
157157

158158
i = torchtrt.Input(shape, dtype=torchtrt.dtype.half)
159-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
159+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
160160
self.assertTrue(self._verify_correctness(ts_i, target))
161161

162162
i = torchtrt.Input(shape, dtype=torch.half)
163-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
163+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
164164
self.assertTrue(self._verify_correctness(ts_i, target))
165165

166166
def test_tensor_format(self):
@@ -176,11 +176,11 @@ def test_tensor_format(self):
176176
}
177177

178178
i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last)
179-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
179+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
180180
self.assertTrue(self._verify_correctness(ts_i, target))
181181

182182
i = torchtrt.Input(shape, format=torch.channels_last)
183-
ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
183+
ts_i = torchtrt.ts.TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
184184
self.assertTrue(self._verify_correctness(ts_i, target))
185185

186186
def test_dynamic_shape(self):
@@ -200,7 +200,7 @@ def test_dynamic_shape(self):
200200
i = torchtrt.Input(
201201
min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape
202202
)
203-
ts_i = torchtrt.ts.TSInput(
203+
ts_i = torchtrt.ts.TorchScriptInput(
204204
min_shape=i.shape["min_shape"],
205205
opt_shape=i.shape["opt_shape"],
206206
max_shape=i.shape["max_shape"],
@@ -214,7 +214,7 @@ def test_dynamic_shape(self):
214214
opt_shape=tuple(opt_shape),
215215
max_shape=tuple(max_shape),
216216
)
217-
ts_i = torchtrt.ts.TSInput(
217+
ts_i = torchtrt.ts.TorchScriptInput(
218218
min_shape=i.shape["min_shape"],
219219
opt_shape=i.shape["opt_shape"],
220220
max_shape=i.shape["max_shape"],
@@ -229,7 +229,7 @@ def test_dynamic_shape(self):
229229
opt_shape=tensor_shape(opt_shape),
230230
max_shape=tensor_shape(max_shape),
231231
)
232-
ts_i = torchtrt.ts.TSInput(
232+
ts_i = torchtrt.ts.TorchScriptInput(
233233
min_shape=i.shape["min_shape"],
234234
opt_shape=i.shape["opt_shape"],
235235
max_shape=i.shape["max_shape"],

0 commit comments

Comments
 (0)