Skip to content

Commit 7323c86

Browse files
authored
fix: Bug in slice operator with default inputs (#2463)
1 parent 60bfa04 commit 7323c86

File tree

4 files changed

+53
-32
lines changed

4 files changed

+53
-32
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ def aten_ops_select(
687687

688688

689689
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
690+
@enforce_tensor_types(
691+
{
692+
0: (TRTTensor,),
693+
}
694+
)
690695
def aten_ops_slice(
691696
ctx: ConversionContext,
692697
target: Target,
@@ -700,9 +705,9 @@ def aten_ops_slice(
700705
SourceIR.ATEN,
701706
name,
702707
args[0],
703-
args[1],
704-
args[2],
705-
args[3],
708+
args_bounds_check(args, 1, replacement=0),
709+
args_bounds_check(args, 2, replacement=None),
710+
args_bounds_check(args, 3, replacement=None),
706711
args_bounds_check(args, 4, replacement=1),
707712
)
708713

@@ -900,6 +905,11 @@ def aten_ops_clone_copy_placeholder(
900905

901906

902907
@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
908+
@enforce_tensor_types(
909+
{
910+
0: (TRTTensor,),
911+
}
912+
)
903913
def aten_ops_expand(
904914
ctx: ConversionContext,
905915
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
55

66
import numpy as np
7+
import tensorrt as trt
78
import torch
89
from torch import SymBool, SymFloat, SymInt
910
from torch.fx.node import Argument, Target
@@ -20,8 +21,6 @@
2021
)
2122
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
2223

23-
import tensorrt as trt
24-
2524
_LOGGER: logging.Logger = logging.getLogger(__name__)
2625

2726

@@ -339,8 +338,8 @@ def get_positive_dim(
339338
) -> Union[int, Tuple[int, ...]]:
340339
"""
341340
Given an integer number or tuple that represents dimension(s) in the array,
342-
transform it to a positive integer dim if it's negative. Otherwise, do
343-
nothing.
341+
transform it to a positive integer dim if it's negative.
342+
Otherwise, truncate it to the dimension size.
344343
345344
Args:
346345
dim (Union[int, Sequence[int]]): An integer or Sequence of integers that represent dimension(s) in an array.
@@ -353,7 +352,8 @@ def get_positive_dim(
353352
def positive_dim(d: int) -> int:
354353
if d < 0:
355354
return d % dim_size
356-
return d
355+
else:
356+
return min(d, dim_size)
357357

358358
return (
359359
positive_dim(dim)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@ def slice_op( # TODO: This should be slice not whatever is in base
2727
name: str,
2828
input: TRTTensor,
2929
dim: int,
30-
start: int,
31-
stop: int,
30+
start: Optional[int],
31+
stop: Optional[int],
3232
step: int,
3333
) -> TRTTensor:
34-
if not isinstance(input, TRTTensor):
35-
raise RuntimeError(
36-
f"slice_tensor received input {input} that is not part "
37-
"of the TensorRT region!"
38-
)
34+
# Special case for start being None
35+
if start is None:
36+
start = 0
37+
38+
# Special case for stop being None
39+
if stop is None:
40+
stop = input.shape[dim]
3941

4042
dim = get_positive_dim(dim, len(input.shape))
4143
start = get_positive_dim(start, input.shape[dim])
@@ -45,9 +47,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
4547
# Check whether slice target dim is dynamic shape dim
4648
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
4749

48-
if stop == 2**63 - 1:
49-
stop = input.shape[dim]
50-
5150
start_slice = [0] * len(input.shape)
5251
start_slice[dim] = start
5352
stride_slice = [1] * len(input.shape)
@@ -68,11 +67,6 @@ def expand(
6867
input_t: TRTTensor,
6968
shape: Shape,
7069
) -> TRTTensor:
71-
if not isinstance(input_t, TRTTensor):
72-
raise RuntimeError(
73-
f"expand received input {input_t} that is not a TensorRT ITensor"
74-
)
75-
7670
shape_rank = len(shape)
7771
initial_tensor_rank = len(input_t.shape)
7872

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from .harness import DispatchTestCase
88

99

10-
class TestSelectConverter(DispatchTestCase):
10+
class TestSliceConverter(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13-
("select_dim_start_stop_step", 0, 0, 7, 2),
14-
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
15-
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
16-
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
17-
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
13+
("slice_dim_start_stop_step", 0, 0, 7, 2),
14+
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
15+
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
16+
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
17+
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
18+
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
19+
("slice_dim_start_stop_step_none", 2, None, None, 1),
1820
]
1921
)
2022
def test_slice(self, _, dim, start, stop, step):
@@ -32,12 +34,27 @@ def forward(self, input):
3234
input,
3335
)
3436

37+
def test_slice_empty(self):
38+
class TestModule(torch.nn.Module):
39+
def __init__(self):
40+
super().__init__()
41+
42+
def forward(self, input):
43+
out = torch.ops.aten.slice.Tensor(input)
44+
return out
45+
46+
input = [torch.randn(10, 10, 3, 1)]
47+
self.run_test(
48+
TestModule(),
49+
input,
50+
)
51+
3552

36-
class TestSelectConverterDynamicShape(DispatchTestCase):
53+
class TestSliceConverterDynamicShape(DispatchTestCase):
3754
@parameterized.expand(
3855
[
39-
("select_dim_start_stop_step", 1, 0, 7, 2),
40-
("select_dim_start_stop_step", 1, 0, 10, 2),
56+
("slice_dim_start_stop_step", 1, 0, 7, 2),
57+
("slice_dim_start_stop_step", 1, 0, 10, 2),
4158
]
4259
)
4360
def test_slice(self, _, dim, start, stop, step):

0 commit comments

Comments (0)