Skip to content

Commit 0363526

Browse files
committed
chore: Linter fixes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5960ac2 commit 0363526

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
3838
+ str(type(input_size))
3939
)
4040

41+
4142
def _parse_op_precision(precision: Any) -> _enums.dtype:
4243
if isinstance(precision, torch.dtype):
4344
if precision == torch.int8:
@@ -192,9 +193,17 @@ def _parse_input_signature(input_signature: Any, depth: int = 0):
192193
if i.shape_mode == Input._ShapeMode.STATIC:
193194
ts_i = TSInput(shape=i.shape, dtype=i.dtype, format=i.format)
194195
elif i.shape_mode == Input._ShapeMode.DYNAMIC:
195-
ts_i = TSInput(min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], dtype=i.dtype, format=i.format)
196+
ts_i = TSInput(
197+
min_shape=i.shape["min_shape"],
198+
opt_shape=i.shape["opt_shape"],
199+
max_shape=i.shape["max_shape"],
200+
dtype=i.dtype,
201+
format=i.format,
202+
)
196203
else:
197-
raise ValueError("Invalid shape mode detected for input while parsing the input_signature")
204+
raise ValueError(
205+
"Invalid shape mode detected for input while parsing the input_signature"
206+
)
198207

199208
clone = _internal_input_to_torch_class_input(ts_i._to_internal())
200209
return clone
@@ -231,9 +240,21 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
231240
ts_inputs = []
232241
for i in inputs:
233242
if i.shape_mode == Input._ShapeMode.STATIC:
234-
ts_inputs.append(TSInput(shape=i.shape, dtype=i.dtype, format=i.format)._to_internal())
243+
ts_inputs.append(
244+
TSInput(
245+
shape=i.shape, dtype=i.dtype, format=i.format
246+
)._to_internal()
247+
)
235248
elif i.shape_mode == Input._ShapeMode.DYNAMIC:
236-
ts_inputs.append(TSInput(min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], dtype=i.dtype, format=i.format)._to_internal())
249+
ts_inputs.append(
250+
TSInput(
251+
min_shape=i.shape["min_shape"],
252+
opt_shape=i.shape["opt_shape"],
253+
max_shape=i.shape["max_shape"],
254+
dtype=i.dtype,
255+
format=i.format,
256+
)._to_internal()
257+
)
237258
info.inputs = ts_inputs
238259

239260
elif compile_spec["input_signature"] is not None:

tests/py/api/test_classes.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,27 @@ def test_dynamic_shape(self):
199199
i = torchtrt.Input(
200200
min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape
201201
)
202-
ts_i = torchtrt.ts.TSInput(min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], dtype=i.dtype, format=i.format)
202+
ts_i = torchtrt.ts.TSInput(
203+
min_shape=i.shape["min_shape"],
204+
opt_shape=i.shape["opt_shape"],
205+
max_shape=i.shape["max_shape"],
206+
dtype=i.dtype,
207+
format=i.format,
208+
)
203209
self.assertTrue(self._verify_correctness(ts_i, target))
204210

205211
i = torchtrt.Input(
206212
min_shape=tuple(min_shape),
207213
opt_shape=tuple(opt_shape),
208214
max_shape=tuple(max_shape),
209215
)
210-
ts_i = torchtrt.ts.TSInput(min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], dtype=i.dtype, format=i.format)
216+
ts_i = torchtrt.ts.TSInput(
217+
min_shape=i.shape["min_shape"],
218+
opt_shape=i.shape["opt_shape"],
219+
max_shape=i.shape["max_shape"],
220+
dtype=i.dtype,
221+
format=i.format,
222+
)
211223
self.assertTrue(self._verify_correctness(ts_i, target))
212224

213225
tensor_shape = lambda shape: torch.randn(shape).shape
@@ -216,7 +228,13 @@ def test_dynamic_shape(self):
216228
opt_shape=tensor_shape(opt_shape),
217229
max_shape=tensor_shape(max_shape),
218230
)
219-
ts_i = torchtrt.ts.TSInput(min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], dtype=i.dtype, format=i.format)
231+
ts_i = torchtrt.ts.TSInput(
232+
min_shape=i.shape["min_shape"],
233+
opt_shape=i.shape["opt_shape"],
234+
max_shape=i.shape["max_shape"],
235+
dtype=i.dtype,
236+
format=i.format,
237+
)
220238
self.assertTrue(self._verify_correctness(ts_i, target))
221239

222240

0 commit comments

Comments
 (0)