Skip to content

Commit 8b1d0d5

Browse files
committed
Removing cast
1 parent 71a452b commit 8b1d0d5

File tree

2 files changed

+7
-6
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import logging
33
import re
4-
from typing import Any, List, Optional, Tuple, Union
4+
from typing import Any, Callable, List, Optional, Tuple, Union
55

66
import numpy as np
77
import tensorrt as trt
@@ -55,7 +55,9 @@ def dynamic_unsupported_with_args(
5555
arg_positions_to_check: Optional[List[int]] = None,
5656
) -> Callable[[torch.fx.Node], bool]:
5757
"""Returns a validator that a node has no dynamic args at specific positions"""
58-
return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check)
58+
return functools.partial(
59+
_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check
60+
)
5961

6062

6163
def _dynamic_unsupported(

py/torch_tensorrt/dynamo/conversion/impl/split.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,17 @@ def split(
2828
f"split received input {input} that is not part " "of the TensorRT region!"
2929
)
3030

31-
dim = cast(int, dim)
3231
dynamic_shape = has_dynamic_shape(input.shape)
3332
if dynamic_shape > 0:
3433
# Check whether slice target dim is dynamic shape dim
3534
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
3635

3736
split_sizes = []
3837
if isinstance(split_size_or_sections, int):
39-
split_sizes.append(cast(int, split_size_or_sections))
38+
split_sizes.append(split_size_or_sections)
4039
else:
4140
for split_size_or_section in split_size_or_sections:
42-
split_sizes.append(cast(int, split_size_or_section))
41+
split_sizes.append(split_size_or_section)
4342

4443
start = [0] * len(input.shape)
4544
stride = [1] * len(start)
@@ -65,7 +64,7 @@ def split(
6564
output = []
6665
for i in range(num_splits):
6766
shape = list(input.shape)
68-
shape[dim] = min(split_sizes[i], cast(int, max_offset - offset))
67+
shape[dim] = min(split_sizes[i], max_offset - offset)
6968
start[dim] = offset
7069
if dynamic_shape:
7170
shape = get_shape_with_dynamic_shape(

0 commit comments

Comments (0)