Skip to content

Commit 1c4b16d

Browse files
EgorLakomkin authored and Wei Wei committed
Add support for torch.nn.functional.max_pool1d in fx2trt. (#15)
Summary: Pull Request resolved: pytorch/fx2trt#15 This commit adds support of max_pool1d in torch fx2trt Reviewed By: 842974287 Differential Revision: D34695867 fbshipit-source-id: 5f06e526edf4cdd5fe8ffee022802aade5808715
1 parent 3047783 commit 1c4b16d

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,53 @@ def acc_ops_fmod(
11011101
)
11021102
return sub_value
11031103

1104+
@tensorrt_converter(acc_ops.max_pool1d, no_explicit_batch_dim=True)
def acc_ops_max_pool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert acc_ops.max_pool1d to TensorRT layers.

    TensorRT has no native 1d pooling layer, so the input is reshaped to
    append a trailing unit dimension, run through a 2d max pool with a
    (kernel_size, 1) window, and reshaped back.

    Args:
        network: TensorRT network being built.
        target: FX target being converted.
        args: positional args of the acc op (unused; all values come in kwargs).
        kwargs: input, kernel_size, stride, padding, dilation, ceil_mode.
        name: base name for the generated TensorRT layers.

    Returns:
        The output TRT tensor of the final squeeze shuffle layer.

    Raises:
        RuntimeError: if the network uses an explicit batch dimension
            (dynamic shape is unsupported here), if kernel_size/stride/
            padding/dilation are not plain ints, or if dilation != 1.
    """
    if not network.has_implicit_batch_dimension:
        # BUGFIX: the original message told the user to use an *explicit*
        # batch dimension, but this converter requires an *implicit* one
        # (see no_explicit_batch_dim=True in the decorator).
        raise RuntimeError(
            "Current implementation does not support dynamic shape. "
            "Make sure that the network has an implicit batch dimension!"
        )

    input_trt = kwargs["input"]

    # Emulate max_pool1d as: unsqueeze -> max_pool2d -> squeeze.
    unsqueeze_layer = network.add_shuffle(input=input_trt)
    unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")

    input_trt = unsqueeze_layer.get_output(0)

    kernel_size = kwargs["kernel_size"]
    stride = kwargs["stride"]
    padding = kwargs["padding"]
    dilation = kwargs["dilation"]
    ceil_mode = kwargs["ceil_mode"]

    # Tuple-valued parameters would silently produce a wrong 2d window
    # below, so insist on plain ints up front.
    if any(
        not isinstance(param, int)
        for param in (kernel_size, stride, padding, dilation)
    ):
        raise RuntimeError(
            "Parameters kernel_size, stride, padding, and dilation should be of type int."
        )
    if dilation != 1:
        raise RuntimeError(
            f"Only support dilation=1 for maxpool, but got {dilation}"
        )

    max_pooling_layer = network.add_pooling(
        input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size, 1)
    )
    max_pooling_layer.stride_nd = (stride, 1)
    max_pooling_layer.padding_nd = (padding, 0)
    set_layer_name(max_pooling_layer, target, name)

    if ceil_mode:
        # Round the output size up instead of down, matching PyTorch's
        # ceil_mode semantics.
        max_pooling_layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    input_trt = max_pooling_layer.get_output(0)
    squeeze_layer = network.add_shuffle(input=input_trt)
    squeeze_layer.reshape_dims = tuple(input_trt.shape[:-1])
    set_layer_name(squeeze_layer, target, name + "_squeeze")
    return squeeze_layer.get_output(0)
11041151

11051152
@tensorrt_converter(acc_ops.max_pool2d)
11061153
def acc_ops_max_pool2d(

test/converters/acc_op/test_maxpool.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,38 @@
88

99

1010
class TestMaxPoolConverter(AccTestCase):
11+
@parameterized.expand(
12+
[
13+
("default", 1),
14+
("kernel_3", 3),
15+
("stride", 1, 2),
16+
param("padding", 2, padding=1),
17+
param("padding_even", 5, padding=2),
18+
param("ceil_mode", 1, ceil_mode=True),
19+
]
20+
)
21+
def test_max_pool1d(self,
22+
test_name,
23+
kernel_size,
24+
stride=1,
25+
padding=0,
26+
dilation=1,
27+
ceil_mode=False,
28+
):
29+
class TestModule(torch.nn.Module):
30+
def __init__(self):
31+
super().__init__()
32+
self.max_pool = torch.nn.MaxPool1d(
33+
kernel_size, stride, padding, ceil_mode=ceil_mode, dilation=dilation
34+
)
35+
36+
def forward(self, x):
37+
return self.max_pool(x)
38+
39+
inputs = [torch.randn(1, 3, 224)]
40+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool1d}, test_explicit_batch_dim=False,)
41+
42+
1143
@parameterized.expand(
1244
[
1345
("default", 1),

test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,7 @@ def test_all_acc_ops_registered(self):
20622062
acc_normalizer._acc_ops,
20632063
{
20642064
acc_ops.linear,
2065+
acc_ops.max_pool1d,
20652066
acc_ops.max_pool2d,
20662067
acc_ops.flatten,
20672068
acc_ops.adaptive_avg_pool2d,

tracer/acc_tracer/acc_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ def squeeze(*, input, dim=None):
8080
return input.squeeze(dim=dim)
8181

8282

83+
@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool1d))
@register_acc_op
def max_pool1d(
    *, input, kernel_size, stride, padding, dilation, ceil_mode, return_indices
):
    """Acc-op wrapper that forwards every argument, unchanged, to
    nn.functional.max_pool1d."""
    pool_kwargs = {
        "input": input,
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "ceil_mode": ceil_mode,
        "return_indices": return_indices,
    }
    return nn.functional.max_pool1d(**pool_kwargs)
97+
8398
@register_acc_op_mapping(op_and_target=("call_function", nn.functional.max_pool2d))
8499
@register_acc_op
85100
def max_pool2d(

0 commit comments

Comments
 (0)