Skip to content

Commit fb70253

Browse files
apbosegs-olive
authored andcommitted
Converter reorg and select operation
select operation correction and linting changes
1 parent 294545c commit fb70253

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from torch_tensorrt.fx.converters.impl.normalization import layer_norm
3030
from torch_tensorrt.fx.converters.impl.normalization import softmax
3131
from torch_tensorrt.fx.converters.impl.squeeze import squeeze
32+
from torch_tensorrt.fx.converters.impl.select import select
3233

3334
_LOGGER: logging.Logger = logging.getLogger(__name__)
3435

@@ -626,6 +627,17 @@ def aten_ops_operator_add(
626627
return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
627628

628629

630+
@tensorrt_converter(torch.ops.aten.select.int)
631+
def aten_ops_select(
632+
network: TRTNetwork,
633+
target: Target,
634+
args: Tuple[Argument, ...],
635+
kwargs: Dict[str, Argument],
636+
name: str,
637+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
638+
return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
639+
640+
629641
@tensorrt_converter(operator.sub)
630642
def aten_ops_operator_sub(
631643
network: TRTNetwork,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import operator
2+
import warnings
3+
from typing import Union, Callable, Any, Optional, cast
4+
5+
import numpy as np
6+
7+
# @manual=//deeplearning/trt/python:py_tensorrt
8+
import tensorrt as trt
9+
import torch
10+
from torch.fx.node import Target
11+
12+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
13+
from torch_tensorrt.fx.converters.converter_utils import (
14+
SourceIR,
15+
get_positive_dim,
16+
has_dynamic_shape,
17+
to_numpy,
18+
)
19+
from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
20+
21+
22+
def select(
23+
network: TRTNetwork,
24+
target: Target,
25+
source_ir: Optional[SourceIR],
26+
name: str,
27+
input: TRTTensor,
28+
dim: Shape,
29+
index: Shape,
30+
) -> TRTTensor:
31+
if not isinstance(input, TRTTensor):
32+
raise RuntimeError(
33+
f"slice_tensor received input {input} that is not part "
34+
"of the TensorRT region!"
35+
)
36+
37+
ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0)
38+
dim = get_positive_dim(cast(int, dim), ranks)
39+
dynamic_shape = has_dynamic_shape(input.shape)
40+
if network.has_implicit_batch_dimension:
41+
if dim == 0:
42+
raise RuntimeError(
43+
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
44+
)
45+
dim = dim - 1
46+
else:
47+
if dynamic_shape:
48+
# Check whether slice target dim is dynamic shape dim
49+
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
50+
index = index
51+
52+
if index >= input.shape[dim]:
53+
raise RuntimeError(
54+
f"cannot have index greater than the dimension length! {input.shape[dim]}"
55+
)
56+
output_shape = list(input.shape)
57+
output_shape[dim] = 1
58+
if dynamic_shape > 0:
59+
output_shape = get_shape_with_dynamic_shape(
60+
network, target, source_ir, name, output_shape, input
61+
)
62+
index_value = torch.tensor(index, dtype=torch.int32)
63+
indices_tensor = network.add_constant(
64+
index_value.shape, to_numpy(index_value)
65+
).get_output(0)
66+
layer = network.add_gather(input, indices_tensor, dim)
67+
out = layer.get_output(0)
68+
if len(out.shape) != 1:
69+
layer = network.add_shuffle(out)
70+
return layer.get_output(0)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
10+
class TestSelectConverterOne(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
("select_dim_index", 1, 0),
14+
]
15+
)
16+
def test_select(self, _, dim, index):
17+
class TestModule(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, input):
22+
return torch.select(input, dim, index)
23+
24+
input = [torch.randn(1, 2)]
25+
self.run_test(
26+
TestModule(),
27+
input,
28+
expected_ops={torch.ops.aten.select.int},
29+
test_explicit_precision=True,
30+
)
31+
32+
33+
class TestSelectConverterTwo(DispatchTestCase):
34+
@parameterized.expand(
35+
[
36+
("select_dim_index", 1, 0),
37+
]
38+
)
39+
def test_select(self, _, dim, index):
40+
class TestModule(torch.nn.Module):
41+
def __init__(self):
42+
super().__init__()
43+
44+
def forward(self, input):
45+
return torch.select(input, dim, index)
46+
47+
input = [torch.randn(4, 4, 4, 4)]
48+
self.run_test(
49+
TestModule(),
50+
input,
51+
expected_ops={torch.ops.aten.select.int},
52+
test_explicit_precision=True,
53+
)
54+
55+
56+
class TestSelectConverterWithDynamicShape(DispatchTestCase):
57+
@parameterized.expand(
58+
[
59+
("select_dim_index", 1, 0),
60+
]
61+
)
62+
def test_select_with_dynamic_shape(self, _, dim, index):
63+
class TestModule(torch.nn.Module):
64+
def __init__(self):
65+
super().__init__()
66+
67+
def forward(self, input):
68+
return torch.select(input, dim, index)
69+
70+
input_spec = [
71+
InputTensorSpec(
72+
shape=(-1, 3, 3),
73+
dtype=torch.float32,
74+
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
75+
),
76+
]
77+
self.run_test_with_dynamic_shape(
78+
TestModule(), input_spec, expected_ops={torch.ops.aten.select.int}
79+
)
80+
81+
82+
if __name__ == "__main__":
83+
run_tests()

0 commit comments

Comments
 (0)