
feat: support aten.index_select converter #2710


Merged (4 commits, Apr 12, 2024)
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2782,3 +2782,28 @@ def aten_ops_roll(
        args[1],
        args_bounds_check(args, 2, []),
    )


@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        2: (TRTTensor,),
    }
)
def aten_ops_index_select(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.select.index_select(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args[2],
    )
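For reference, a minimal PyTorch sketch of the op this converter maps (the tensor names below are illustrative, not from the PR):

```python
import torch

source = torch.randn(5, 10)
indices = torch.tensor([1, 2, 3], dtype=torch.int32)

# aten.index_select picks whole slices of `source` along `dim`
# at the positions listed in `indices`.
out = torch.ops.aten.index_select.default(source, 1, indices)
assert torch.equal(out, source[:, [1, 2, 3]])
```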
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -12,6 +12,7 @@
    elementwise,
    embedding,
    grid,
+    index,
Collaborator: This can likely be removed - it seems to be causing a circular import error in CI.

Collaborator Author: Thanks! It seems I overlooked removing an unnecessary import.

    linear,
    matmul,
    normalization,
30 changes: 24 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
    # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
    # If any is not this flag will be set to False
    _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
    )
    is_numpy = all(
        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
        return identity_layer.get_output(0)
    elif len(tensor_indices) == 1:
        indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
        )
        index = adv_indx_indices[0]
        _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                cum_adv_index = cum_adv_index + adv_index
                multiplier = multiplier * input_shape[adv_indx_indices[i]]
            cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
            )
        else:
            multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@
            adv_indx_count
            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
        ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
            concat_tensor_reshape.append(
                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
            )
@@ -287,7 +287,7 @@
                source_ir,
            )
            unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
            _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

            # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@
            reshape_output = unfold_advanced_shuffle_layer.get_output(0)

        else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
            concat_final_tensor = []
            concat_final_tensor.append(cum_adv_index_shape_tensor)
            for i in range(0, rank):
@@ -370,3 +370,21 @@ def index(
        reshape_output = reshape_layer.get_output(0)

    return reshape_output


def index_select(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: TRTTensor,
) -> TRTTensor:
    # The axis parameter specifies the dimension along which to index.
    dim = get_positive_dim(dim, len(input.shape))
    gather_layer = ctx.net.add_gather(input, index, axis=dim)

    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

    return gather_layer.get_output(0)
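The lowering above is a single TensorRT gather: `get_positive_dim` first maps a possibly negative `dim` onto the tensor's rank, and the gather layer then selects whole slices along that axis. A minimal sketch of that behavior, assuming `get_positive_dim` acts like Python's modulo on the rank (the exact helper semantics in torch_tensorrt may differ in detail):

```python
import torch

def positive_dim(dim: int, rank: int) -> int:
    # Hypothetical stand-in for get_positive_dim: -1 with rank 3 becomes 2.
    return dim % rank

x = torch.randn(10, 5, 10)
idx = torch.tensor([3, 3, 4])
axis = positive_dim(-1, x.ndim)  # -> 2
# Selecting along the normalized axis matches selecting along the negative one.
assert torch.equal(torch.index_select(x, -1, idx), torch.index_select(x, axis, idx))
```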
41 changes: 41 additions & 0 deletions tests/py/dynamo/conversion/test_index_select_aten.py
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIndexSelectConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("1d_input", (10,), 0, (1,)),
            ("2d_input_dim_0", (10, 3), 0, (0, 2)),
            ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
            ("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)),
            ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
            ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
Collaborator: Add a test case for a negative dim input.

Collaborator Author: I have added a test case for a negative dim input and verified it passes. Thank you!
("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)),
("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)),
]
)
def test_index_select(self, _, source_shape, dim, indices_val):
class TestIndexSelect(torch.nn.Module):
def forward(self, source_tensor, indices_tensor):
return torch.ops.aten.index_select.default(
source_tensor, dim, indices_tensor
)

input = [
torch.randn(*source_shape, dtype=torch.float32),
torch.tensor([*indices_val], dtype=torch.int32),
]

self.run_test(
TestIndexSelect(),
input,
)


if __name__ == "__main__":
run_tests()
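Since the file ends with `run_tests()`, the parameterized cases can also be run directly, e.g. `python tests/py/dynamo/conversion/test_index_select_aten.py` (assuming a working torch_tensorrt development install).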