Skip to content

Commit 9bbdc9e

Browse files
apbosegs-olive
authored andcommitted
converter reorg and slice
converter reorg slice op Correcting linting error and slice changes Correcting the slice operation
1 parent fb70253 commit 9bbdc9e

File tree

5 files changed

+221
-0
lines changed

5 files changed

+221
-0
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from torch_tensorrt.fx.converters.impl.normalization import softmax
3131
from torch_tensorrt.fx.converters.impl.squeeze import squeeze
3232
from torch_tensorrt.fx.converters.impl.select import select
33+
from torch_tensorrt.fx.converters.impl.slice import slice_op
3334

3435
_LOGGER: logging.Logger = logging.getLogger(__name__)
3536

@@ -673,6 +674,27 @@ def aten_ops_sym_numel(
673674
return reduce_layer.get_output(0)
674675

675676

677+
@tensorrt_converter(torch.ops.aten.slice.Tensor)
678+
def aten_ops_slice(
679+
network: TRTNetwork,
680+
target: Target,
681+
args: Tuple[Argument, ...],
682+
kwargs: Dict[str, Argument],
683+
name: str,
684+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
685+
return slice_op(
686+
network,
687+
target,
688+
SourceIR.ATEN,
689+
name,
690+
args[0],
691+
args[1],
692+
args[2],
693+
args[3],
694+
args[4],
695+
)
696+
697+
676698
@tensorrt_converter(torch.ops.aten.sym_size)
677699
def aten_ops_sym_size(
678700
network: TRTNetwork,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ops import *
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import operator
2+
import warnings
3+
from typing import Optional, cast
4+
5+
import numpy as np
6+
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Target
10+
11+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
12+
from torch_tensorrt.fx.converters.converter_utils import (
13+
SourceIR,
14+
has_dynamic_shape,
15+
set_layer_name,
16+
)
17+
18+
from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
19+
20+
21+
def slice(
22+
network: TRTNetwork,
23+
target: Target,
24+
source_ir: Optional[SourceIR],
25+
name: str,
26+
input: TRTTensor,
27+
start: Shape,
28+
shape: Shape,
29+
stride: Shape,
30+
) -> TRTTensor:
31+
dynamic_shape = has_dynamic_shape(input.shape)
32+
if dynamic_shape:
33+
shape = get_shape_with_dynamic_shape(
34+
network, target, source_ir, name, shape, input
35+
)
36+
layer = network.add_slice(
37+
input,
38+
start=start,
39+
shape=[] if dynamic_shape else shape,
40+
stride=stride,
41+
)
42+
if dynamic_shape:
43+
layer.set_input(2, shape)
44+
set_layer_name(layer, target, name)
45+
return layer.get_output(0)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import operator
2+
import warnings
3+
from typing import Optional, cast
4+
import math
5+
6+
import numpy as np
7+
8+
# @manual=//deeplearning/trt/python:py_tensorrt
9+
import tensorrt as trt
10+
import torch
11+
from torch.fx.node import Target
12+
13+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
14+
from torch_tensorrt.fx.converters.converter_utils import (
15+
SourceIR,
16+
set_layer_name,
17+
get_positive_dim,
18+
has_dynamic_shape,
19+
)
20+
from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
21+
from torch_tensorrt.fx.converters.impl.slice.base import slice
22+
23+
24+
def slice_op(
25+
network: TRTNetwork,
26+
target: Target,
27+
source_ir: Optional[SourceIR],
28+
name: str,
29+
input: TRTTensor,
30+
dim: int,
31+
start: int,
32+
stop: int,
33+
step: int,
34+
) -> TRTTensor:
35+
if not isinstance(input, TRTTensor):
36+
raise RuntimeError(
37+
f"slice_tensor received input {input} that is not part "
38+
"of the TensorRT region!"
39+
)
40+
41+
ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0)
42+
dim = get_positive_dim(cast(int, dim), ranks)
43+
dynamic_shape = has_dynamic_shape(input.shape)
44+
if network.has_implicit_batch_dimension:
45+
if dim == 0:
46+
raise RuntimeError(
47+
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
48+
)
49+
dim = dim - 1
50+
else:
51+
if dynamic_shape:
52+
# Check whether slice target dim is dynamic shape dim
53+
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
54+
start_int = cast(int, start)
55+
stop_int = cast(int, stop)
56+
step_int = cast(int, step)
57+
start = [0] * len(input.shape)
58+
start[dim] = start_int
59+
stride = [1] * len(start)
60+
stride[dim] = step_int
61+
output_shape = list(input.shape)
62+
output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
63+
64+
return slice(network, target, source_ir, name, input, start, output_shape, stride)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
10+
class TestSelectConverterImplicitBatch(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
("select_dim_start_stop_step", 0, 0, 7, 2),
14+
]
15+
)
16+
def test_slice(self, _, dim, start, stop, step):
17+
class TestModule(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, input):
22+
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
23+
return out
24+
25+
input = [torch.randn(10, 2, 3, 1)]
26+
self.run_test(
27+
TestModule(),
28+
input,
29+
expected_ops={torch.ops.aten.slice.Tensor},
30+
)
31+
32+
33+
class TestSelectConverterExplicitBatch(DispatchTestCase):
34+
@parameterized.expand(
35+
[
36+
("select_dim_start_stop_step", 1, 0, 7, 2),
37+
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
38+
]
39+
)
40+
def test_slice(self, _, dim, start, stop, step):
41+
class TestModule(torch.nn.Module):
42+
def __init__(self):
43+
super().__init__()
44+
45+
def forward(self, input):
46+
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
47+
return out
48+
49+
input = [torch.randn(10, 10, 3, 1)]
50+
self.run_test(
51+
TestModule(),
52+
input,
53+
expected_ops={torch.ops.aten.slice.Tensor},
54+
test_explicit_precision=True,
55+
)
56+
57+
58+
class TestSelectConverterDynamicShape(DispatchTestCase):
59+
@parameterized.expand(
60+
[
61+
("select_dim_start_stop_step", 1, 0, 7, 2),
62+
("select_dim_start_stop_step", 1, 0, 10, 2),
63+
]
64+
)
65+
def test_slice(self, _, dim, start, stop, step):
66+
class TestModule(torch.nn.Module):
67+
def __init__(self):
68+
super().__init__()
69+
70+
def forward(self, input):
71+
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
72+
return out
73+
74+
input_specs = [
75+
InputTensorSpec(
76+
shape=(1, 10, -1),
77+
dtype=torch.float32,
78+
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
79+
),
80+
]
81+
self.run_test_with_dynamic_shape(
82+
TestModule(),
83+
input_specs,
84+
expected_ops={torch.ops.aten.slice.Tensor},
85+
)
86+
87+
88+
if __name__ == "__main__":
89+
run_tests()

0 commit comments

Comments
 (0)