
Commit 294545c

apbosegs-olive authored and committed

Converter reorg and squeeze operator

Correcting the squeeze operator implementation, a linting error, and the acc squeeze test; adding the condition to convert dim to int and removing the comment.
1 parent 37d1168 commit 294545c

File tree

py/torch_tensorrt/fx/converters/acc_ops_converters.py
py/torch_tensorrt/fx/converters/aten_ops_converters.py
py/torch_tensorrt/fx/converters/impl/squeeze.py
py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py

5 files changed: +171 -35 lines


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 33 deletions

@@ -38,6 +38,7 @@
 )
 from torch_tensorrt.fx.converters.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.fx.converters.impl.squeeze import squeeze

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -2064,40 +2065,14 @@ def acc_ops_squeeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
-    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
-    # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
-
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    return squeeze(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["dim"],
     )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
-
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
-
-    output_shape = []
-    for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
-            continue
-        output_shape.append(s)
-    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)


 @tensorrt_converter(acc_ops.add)
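
The acc converter body above now reduces to a call into the shared squeeze helper. What that helper ultimately computes is simple shape arithmetic: drop a requested axis only when its static size is 1. A minimal sketch in plain Python (no TensorRT needed; squeeze_shape is a hypothetical name used only for illustration):

from typing import Sequence, Tuple

def squeeze_shape(shape: Sequence[int], dims: Sequence[int]) -> Tuple[int, ...]:
    # Normalize negative indices, the way get_positive_dim does.
    rank = len(shape)
    wanted = {d % rank for d in dims}
    # Keep every axis unless it was requested and is statically 1.
    return tuple(s for i, s in enumerate(shape) if not (i in wanted and s == 1))

assert squeeze_shape((2, 1, 3, 1), [1, 3]) == (2, 3)
assert squeeze_shape((2, 1, 3), [-2]) == (2, 3)    # negative dim wraps around
assert squeeze_shape((2, 2, 3), [1]) == (2, 2, 3)  # size != 1 dims are kept, as in torch.squeeze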

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 13 additions & 0 deletions

@@ -28,6 +28,7 @@
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
 from torch_tensorrt.fx.converters.impl.normalization import layer_norm
 from torch_tensorrt.fx.converters.impl.normalization import softmax
+from torch_tensorrt.fx.converters.impl.squeeze import squeeze

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -457,6 +458,18 @@ def aten_ops_sub(
     return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)


+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
     network: TRTNetwork,
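
The two decorators register one converter for both aten overloads: torch.ops.aten.squeeze.dim carries a single int, while torch.ops.aten.squeeze.dims carries a list, and the shared helper normalizes either form. A quick illustration in plain PyTorch (assumes a build recent enough to expose the squeeze.dims overload, roughly 2.0+):

import torch

x = torch.randn(2, 1, 3, 1)
one = torch.ops.aten.squeeze.dim(x, 1)         # single int dim
many = torch.ops.aten.squeeze.dims(x, [1, 3])  # list of dims
assert one.shape == (2, 3, 1)
assert many.shape == (2, 3)
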
py/torch_tensorrt/fx/converters/impl/squeeze.py (new file)

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import operator
+import warnings
+from typing import Optional, cast, Any
+
+import numpy as np
+
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+)
+
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+    set_layer_name,
+)
+
+from torch_tensorrt.fx.utils import get_dynamic_dims
+
+
+def squeeze(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[Any] = None,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+    dims = []
+    if dim is not None:
+        if isinstance(dim, int):
+            dims.append(cast(Optional[int], dim))
+        else:
+            for dim in dim:
+                dims.append(cast(Optional[int], dim))
+
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
+
+    for dim in dims:
+        dim = cast(Optional[int], dim)
+        dim = get_positive_dim(
+            dim,
+            len(input.shape) + (1 if network.has_implicit_batch_dimension else 0),
+        )
+        if network.has_implicit_batch_dimension:
+            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+            dim -= 1
+
+        assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
+        assert (
+            len(get_dynamic_dims(input.shape)) <= 1
+        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+
+    output_shape = []
+    for i, s in enumerate(input.shape):
+        if (i in dims) and s == 1:
+            continue
+        output_shape.append(s)
+    layer = network.add_shuffle(input)
+    layer.reshape_dims = tuple(output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
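
After validating the dims, the helper emits a single TensorRT shuffle layer whose reshape_dims is the squeezed shape. A standalone sketch of the same construction using the TensorRT Python API directly (a sketch assuming a TensorRT 8.x-style API; the input name and shapes are arbitrary illustration values):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# Explicit-batch network, the mode the dynamic-shape checks above assume.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (2, 1, 3))
shuffle = network.add_shuffle(x)  # the same layer kind squeeze() emits
shuffle.reshape_dims = (2, 3)     # static output shape with the size-1 dim dropped
network.mark_output(shuffle.get_output(0))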

py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py

Lines changed: 9 additions & 2 deletions

@@ -12,7 +12,12 @@ def forward(self, x):
                 return x.squeeze(2)

         inputs = [torch.randn(1, 2, 1)]
-        self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze})
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops={acc_ops.squeeze},
+            test_implicit_batch_dim=False,
+        )

     # Testing with shape=(-1, -1, -1, -1) results in error:
     # AssertionError: We don't support squeeze dynamic dim.

@@ -33,7 +38,9 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            Squeeze(), input_specs, expected_ops={acc_ops.squeeze}
+            Squeeze(),
+            input_specs,
+            expected_ops={acc_ops.squeeze},
         )
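
For the dynamic-shape test, each InputTensorSpec marks dynamic axes with -1 and supplies shape_ranges as (min, opt, max) triples of concrete shapes. A minimal spec matching the shapes used in these tests:

import torch
from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec

# -1 marks the dynamic dim; each shape_ranges entry gives the
# (min, opt, max) concrete shapes TensorRT may see at runtime.
spec = InputTensorSpec(
    shape=(-1, 2, 1),
    dtype=torch.float32,
    shape_ranges=[((1, 2, 1), (1, 2, 1), (3, 2, 1))],
)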

py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py (new file)

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSqueezeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (0), (2, 1)),
+            ("3d_one_dim", (0), (2, 2, 1)),
+            ("3d_two_dim", (0, 1), (2, 1, 1)),
+            ("4d_dim", (0, 1, 2), (2, 2, 1, 1)),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        inputs = [torch.randn(*init_size)]
+        expected_op = {}
+        if isinstance(dim, int) == 1:
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops=expected_op,
+        )
+
+
+class TestSqueezeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
+            ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+            # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size, shape_range):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        if isinstance(dim, int) == 1:
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        input_specs = [
+            InputTensorSpec(
+                shape=init_size,
+                dtype=torch.float32,
+                shape_ranges=shape_range,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Squeeze(),
+            input_specs,
+            expected_ops=expected_op,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
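
One subtlety in the parameter lists above: a parenthesized single value is not a tuple in Python, so entries like (0) and (1) pass a plain int and exercise the squeeze.dim overload; only genuine tuples such as (0, 1) reach squeeze.dims. A two-line check:

assert isinstance((0), int)       # (0) is just 0 -> squeeze.dim path
assert isinstance((0, 1), tuple)  # a real tuple  -> squeeze.dims path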
