Skip to content

fix: Bug in slice operator with default inputs #2463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ def aten_ops_select(


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_slice(
ctx: ConversionContext,
target: Target,
Expand All @@ -700,9 +705,9 @@ def aten_ops_slice(
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args_bounds_check(args, 1, replacement=0),
args_bounds_check(args, 2, replacement=None),
args_bounds_check(args, 3, replacement=None),
args_bounds_check(args, 4, replacement=1),
)

Expand Down Expand Up @@ -900,6 +905,11 @@ def aten_ops_clone_copy_placeholder(


@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_expand(
ctx: ConversionContext,
target: Target,
Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
Expand All @@ -20,8 +21,6 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -339,8 +338,8 @@ def get_positive_dim(
) -> Union[int, Tuple[int, ...]]:
"""
Given an integer number or tuple that represents dimension(s) in the array,
transform it to a positive integer dim if it's negative. Otherwise, do
nothing.
transform it to a positive integer dim if it's negative.
Otherwise, truncate it to the dimension size

Args:
dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array.
Expand All @@ -353,7 +352,8 @@ def get_positive_dim(
def positive_dim(d: int) -> int:
if d < 0:
return d % dim_size
return d
else:
return min(d, dim_size)

return (
positive_dim(dim)
Expand Down
24 changes: 9 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ def slice_op( # TODO: This should be slice not whatever is in base
name: str,
input: TRTTensor,
dim: int,
start: int,
stop: int,
start: Optional[int],
stop: Optional[int],
step: int,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"slice_tensor received input {input} that is not part "
"of the TensorRT region!"
)
# Special case for start being None
if start is None:
start = 0

# Special case for stop being None
if stop is None:
stop = input.shape[dim]

dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
Expand All @@ -45,9 +47,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"

if stop == 2**63 - 1:
stop = input.shape[dim]

start_slice = [0] * len(input.shape)
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
Expand All @@ -68,11 +67,6 @@ def expand(
input_t: TRTTensor,
shape: Shape,
) -> TRTTensor:
if not isinstance(input_t, TRTTensor):
raise RuntimeError(
f"expand received input {input_t} that is not a TensorRT ITensor"
)

shape_rank = len(shape)
initial_tensor_rank = len(input_t.shape)

Expand Down
35 changes: 26 additions & 9 deletions tests/py/dynamo/conversion/test_slice_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from .harness import DispatchTestCase


class TestSelectConverter(DispatchTestCase):
class TestSliceConverter(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 0, 0, 7, 2),
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step", 0, 0, 7, 2),
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
("slice_dim_start_stop_step_none", 2, None, None, 1),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand All @@ -32,12 +34,27 @@ def forward(self, input):
input,
)

def test_slice_empty(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.ops.aten.slice.Tensor(input)
return out

input = [torch.randn(10, 10, 3, 1)]
self.run_test(
TestModule(),
input,
)


class TestSelectConverterDynamicShape(DispatchTestCase):
class TestSliceConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 1, 0, 7, 2),
("select_dim_start_stop_step", 1, 0, 10, 2),
("slice_dim_start_stop_step", 1, 0, 7, 2),
("slice_dim_start_stop_step", 1, 0, 10, 2),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand Down