Skip to content

Commit b5dc751

Browse files
committed
fix: Bug in slice operator with default inputs (#2463)
1 parent ad6fa22 commit b5dc751

File tree

4 files changed

+52
-30
lines changed

4 files changed

+52
-30
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ def aten_ops_select(
687687

688688

689689
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
690+
@enforce_tensor_types(
691+
{
692+
0: (TRTTensor,),
693+
}
694+
)
690695
def aten_ops_slice(
691696
ctx: ConversionContext,
692697
target: Target,
@@ -700,9 +705,9 @@ def aten_ops_slice(
700705
SourceIR.ATEN,
701706
name,
702707
args[0],
703-
args[1],
704-
args[2],
705-
args[3],
708+
args_bounds_check(args, 1, replacement=0),
709+
args_bounds_check(args, 2, replacement=None),
710+
args_bounds_check(args, 3, replacement=None),
706711
args_bounds_check(args, 4, replacement=1),
707712
)
708713

@@ -877,6 +882,11 @@ def aten_ops_clone_copy_placeholder(
877882

878883

879884
@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
885+
@enforce_tensor_types(
886+
{
887+
0: (TRTTensor,),
888+
}
889+
)
880890
def aten_ops_expand(
881891
ctx: ConversionContext,
882892
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def get_positive_dim(
339339
) -> Union[int, Tuple[int, ...]]:
340340
"""
341341
Given an integer number or tuple that represents dimension(s) in the array,
342-
transform it to a positive integer dim if it's negative. Otherwise, do
343-
nothing.
342+
transform it to a positive integer dim if it's negative.
343+
Otherwise, truncate it to the dimension size
344344
345345
Args:
346346
dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array.
@@ -353,7 +353,8 @@ def get_positive_dim(
353353
def positive_dim(d: int) -> int:
354354
if d < 0:
355355
return d % dim_size
356-
return d
356+
else:
357+
return min(d, dim_size)
357358

358359
return (
359360
positive_dim(dim)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ def slice_op( # TODO: This should be slice not whatever is in base
2121
name: str,
2222
input: TRTTensor,
2323
dim: int,
24-
start: int,
25-
stop: int,
24+
start: Optional[int],
25+
stop: Optional[int],
2626
step: int,
2727
) -> TRTTensor:
28-
if not isinstance(input, TRTTensor):
29-
raise RuntimeError(
30-
f"slice_tensor received input {input} that is not part "
31-
"of the TensorRT region!"
32-
)
28+
# Special case for start being None
29+
if start is None:
30+
start = 0
31+
32+
# Special case for stop being None
33+
if stop is None:
34+
stop = input.shape[dim]
3335

3436
dim = get_positive_dim(dim, len(input.shape))
3537
start = get_positive_dim(start, input.shape[dim])
@@ -39,9 +41,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
3941
# Check whether slice target dim is dynamic shape dim
4042
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
4143

42-
if stop == 2**63 - 1:
43-
stop = input.shape[dim]
44-
4544
start_slice = [0] * len(input.shape)
4645
start_slice[dim] = start
4746
stride_slice = [1] * len(input.shape)
@@ -62,11 +61,6 @@ def expand(
6261
input_t: TRTTensor,
6362
shape: Shape,
6463
) -> TRTTensor:
65-
if not isinstance(input_t, TRTTensor):
66-
raise RuntimeError(
67-
f"expand received input {input_t} that is not a TensorRT ITensor"
68-
)
69-
7064
shape_rank = len(shape)
7165
initial_tensor_rank = len(input_t.shape)
7266

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from .harness import DispatchTestCase
88

99

10-
class TestSelectConverter(DispatchTestCase):
10+
class TestSliceConverter(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13-
("select_dim_start_stop_step", 0, 0, 7, 2),
14-
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
15-
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
16-
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
17-
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
13+
("slice_dim_start_stop_step", 0, 0, 7, 2),
14+
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
15+
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
16+
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
17+
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
18+
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
19+
("slice_dim_start_stop_step_none", 2, None, None, 1),
1820
]
1921
)
2022
def test_slice(self, _, dim, start, stop, step):
@@ -32,12 +34,27 @@ def forward(self, input):
3234
input,
3335
)
3436

37+
def test_slice_empty(self):
38+
class TestModule(torch.nn.Module):
39+
def __init__(self):
40+
super().__init__()
41+
42+
def forward(self, input):
43+
out = torch.ops.aten.slice.Tensor(input)
44+
return out
45+
46+
input = [torch.randn(10, 10, 3, 1)]
47+
self.run_test(
48+
TestModule(),
49+
input,
50+
)
51+
3552

36-
class TestSelectConverterDynamicShape(DispatchTestCase):
53+
class TestSliceConverterDynamicShape(DispatchTestCase):
3754
@parameterized.expand(
3855
[
39-
("select_dim_start_stop_step", 1, 0, 7, 2),
40-
("select_dim_start_stop_step", 1, 0, 10, 2),
56+
("slice_dim_start_stop_step", 1, 0, 7, 2),
57+
("slice_dim_start_stop_step", 1, 0, 10, 2),
4158
]
4259
)
4360
def test_slice(self, _, dim, start, stop, step):

0 commit comments

Comments
 (0)