Skip to content

Commit 96bdead

Browse files
committed
Implementation of slice and select operations
1 parent 79286e7 commit 96bdead

File tree

4 files changed

+138
-24
lines changed

4 files changed

+138
-24
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,19 @@ def aten_ops_sigmoid(
572572
"input": args[0],
573573
}
574574
return add_sigmoid(network, target, kwargs_new, name)
575+
576+
577+
@tensorrt_converter(torch.ops.aten.select)
578+
def aten_ops_select(
579+
network: TRTNetwork,
580+
target: Target,
581+
args: Tuple[Argument, ...],
582+
kwargs: Dict[str, Argument],
583+
name: str,
584+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
585+
kwargs_new = {
586+
"input": args[0],
587+
"dim": args[1],
588+
"index": args[2],
589+
}
590+
return add_select(network, target.kwargs_new, name)

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,27 @@ def type_cast(
543543
layer_i.set_output_type(0, cast_type)
544544
set_layer_name(layer_i, target, f"{name}_dtype_change")
545545
return layer_i.get_output(0)
546+
547+
548+
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
549+
"""
550+
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
551+
quantized it will be dequantized first.
552+
553+
Args:
554+
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
555+
556+
Returns:
557+
A Numpy array.
558+
"""
559+
560+
if tensor is None:
561+
return tensor
562+
563+
assert isinstance(
564+
tensor, torch.Tensor
565+
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
566+
if tensor.is_quantized:
567+
tensor = tensor.dequantize()
568+
569+
return tensor.cpu().detach().contiguous().numpy()

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .converter_utils import prepend_ones
2323
from .converter_utils import has_dynamic_shape
2424
from .converter_utils import get_shape_with_dynamic_shape
25+
from .converter_utils import to_numpy
2526

2627
from ..types import (
2728
Shape,
@@ -278,30 +279,6 @@ def trunc_div(
278279
return output
279280

280281

281-
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
282-
"""
283-
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
284-
quantized it will be dequantized first.
285-
286-
Args:
287-
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
288-
289-
Returns:
290-
A Numpy array.
291-
"""
292-
293-
if tensor is None:
294-
return tensor
295-
296-
assert isinstance(
297-
tensor, torch.Tensor
298-
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
299-
if tensor.is_quantized:
300-
tensor = tensor.dequantize()
301-
302-
return tensor.cpu().detach().contiguous().numpy()
303-
304-
305282
def trt_dtype_to_torch_dtype(trt_dtype):
306283
table = {
307284
trt.bool: torch.bool,
@@ -1050,3 +1027,44 @@ def add_expand(network, target, kwargs, name):
10501027
layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
10511028
set_layer_name(layer, target, name)
10521029
return layer.get_output(0)
1030+
1031+
1032+
def add_select(network, target, kwargs, name):
1033+
input_val = kwargs["input"]
1034+
if not isinstance(input_val, TRTTensor):
1035+
raise RuntimeError(
1036+
f"slice_tensor received input {input_val} that is not part "
1037+
"of the TensorRT region!"
1038+
)
1039+
1040+
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
1041+
dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
1042+
dynamic_shape = has_dynamic_shape(input_val.shape)
1043+
if network.has_implicit_batch_dimension:
1044+
if dim == 0:
1045+
raise RuntimeError(
1046+
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
1047+
)
1048+
dim = dim - 1
1049+
else:
1050+
if dynamic_shape:
1051+
# Check whether slice target dim is dynamic shape dim
1052+
assert (
1053+
input_val.shape[dim] != -1
1054+
), "Can't select on negative shape dimension!"
1055+
index = kwargs[2]
1056+
if index >= input_val.shape[dim]:
1057+
raise RuntimeError(
1058+
f"cannot have index greater than the dimension length! {input_val.shape[dim]}"
1059+
)
1060+
output_shape = list(input_val.shape)
1061+
output_shape[dim] = 1
1062+
if dynamic_shape > 0:
1063+
output_shape = get_shape_with_dynamic_shape(
1064+
network, output_shape, input_val, target, name
1065+
)
1066+
layer = network.add_gather(input_val, dim, index)
1067+
out = layer.getOutput(0)
1068+
if len(out.shape) != 1:
1069+
layer = network.add_shuffle(out)
1070+
return layer.getOutput(0)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
10+
class TestSelectConverter(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
("select_dim_index", 2, 1),
14+
]
15+
)
16+
def test_select(self, _, dim_test, index_test):
17+
class TestModule(torch.nn.Module):
18+
def __init__(self, dim, index):
19+
super().__init__()
20+
self.dim = dim
21+
self.index = index
22+
23+
def forward(self, input):
24+
return torch.select(input, self.dim, self.index)
25+
26+
input = [torch.randn(1, 3, 32)]
27+
self.run_test(
28+
TestModule(dim_test, index_test),
29+
input,
30+
expected_ops={torch.ops.aten.select},
31+
test_explicit_precision=True,
32+
)
33+
34+
# def test_select_with_dynamic_shape(self, _, dim_test, index_test):
35+
# class TestModule(torch.nn.Module):
36+
# def __init__(self, dim, index):
37+
# super().__init__()
38+
# self.dim = dim
39+
# self.index = index
40+
# def forward(self, input):
41+
# return torch.select(input, self.dim, self.index)
42+
43+
# input_spec = [
44+
# InputTensorSpec(
45+
# shape=(-1, 3, 32),
46+
# dtype=torch.float32,
47+
# shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))],
48+
# ),
49+
# ]
50+
# self.run_test_with_dynamic_shape(
51+
# TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select}
52+
# )
53+
54+
55+
if __name__ == "__main__":
56+
run_tests()

0 commit comments

Comments
 (0)