Skip to content

Commit c41a6db

Browse files
author
Wei Wei
committed
[fx2trt] avg_pool, max_pool ops improvement
Summary: nn.avg_pool1d, nn.avg_pool2d, nn.max_pool1d, nn.max_pool2d will set stride=kernel_size if stride=None, but nn.functional.avg_pool and nn.functional.max_pool do not set that. This passes stride=None through to TRT and leads to a TRT error. 1. This diff fixes this gap. 2. Add more test cases. 3. Fix the LearningToPaint model in torchbench. Reviewed By: yinghai Differential Revision: D34953195 fbshipit-source-id: 0eb73de29bc7461be1c7c6fa96f820829b2cce10
1 parent 12e12cc commit c41a6db

File tree

3 files changed

+140
-9
lines changed

3 files changed

+140
-9
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,24 +1121,28 @@ def acc_ops_max_pool1d(
11211121

11221122
input_trt = unsqueeze_layer.get_output(0)
11231123

1124-
kernel_size = kwargs["kernel_size"]
1125-
stride = kwargs["stride"]
1126-
padding = kwargs["padding"]
1127-
dilation = kwargs["dilation"]
1124+
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 1)
1125+
stride = extend_attr_to_tuple(kwargs["stride"], 1)
1126+
padding = extend_attr_to_tuple(kwargs["padding"], 1)
1127+
dilation = extend_attr_to_tuple(kwargs["dilation"], 1)
1128+
11281129
ceil_mode = kwargs["ceil_mode"]
11291130

1130-
if any([not isinstance(param, int) for param in [kernel_size, stride, padding, dilation]]):
1131+
if len(stride) == 0 or stride[0] == None:
1132+
stride = kernel_size
1133+
1134+
if any([not isinstance(param, int) for param in [kernel_size[0], stride[0], padding[0], dilation[0]]]):
11311135
raise RuntimeError(f"Parameters kernel_size, stride, padding, and dilation should be of type int.")
1132-
if dilation != 1:
1136+
if dilation[0] != 1:
11331137
raise RuntimeError(
11341138
f"Only support dilation=1 for maxpool, but got {dilation}"
11351139
)
11361140

11371141
max_pooling_layer = network.add_pooling(
1138-
input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size, 1)
1142+
input=input_trt, type=trt.PoolingType.MAX, window_size=(kernel_size[0], 1)
11391143
)
1140-
max_pooling_layer.stride_nd = (stride, 1)
1141-
max_pooling_layer.padding_nd = (padding, 0)
1144+
max_pooling_layer.stride_nd = stride +(1,)
1145+
max_pooling_layer.padding_nd = padding + (0,)
11421146
set_layer_name(max_pooling_layer, target, name)
11431147

11441148
if ceil_mode:
@@ -1171,6 +1175,9 @@ def acc_ops_max_pool2d(
11711175
dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
11721176
ceil_mode = kwargs["ceil_mode"]
11731177

1178+
if len(stride) == 0 or stride[0] == None:
1179+
stride = kernel_size
1180+
11741181
if dilation != (1, 1):
11751182
raise RuntimeError(
11761183
f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
@@ -1445,6 +1452,9 @@ def acc_ops_avg_pool1d(
14451452
ceil_mode = kwargs["ceil_mode"]
14461453
count_include_pad = kwargs["count_include_pad"]
14471454

1455+
if len(stride) == 0 or stride[0] == None:
1456+
stride = kernel_size
1457+
14481458
shuffle_layer = network.add_shuffle(input_val)
14491459
shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,)
14501460
set_layer_name(shuffle_layer, target, name + "_shuffle1")
@@ -1493,6 +1503,9 @@ def acc_ops_avg_pool2d(
14931503
count_include_pad = kwargs["count_include_pad"]
14941504
divisor_override = kwargs["divisor_override"]
14951505

1506+
if len(stride) == 0 or stride[0] == None:
1507+
stride = kernel_size
1508+
14961509
if divisor_override:
14971510
raise RuntimeError("TensorRT does not support divisor_override.")
14981511

test/converters/acc_op/test_avgpool.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,70 @@ def forward(self, x):
8484
inputs = [torch.randn(1, 3, 224, 224)]
8585
self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d})
8686

87+
@parameterized.expand(
88+
[
89+
("kernal_size", 1),
90+
param("stride", 2, stride=()),
91+
]
92+
)
93+
def test_stride_none__avg_pool1d(
94+
self,
95+
test_name,
96+
kernel_size,
97+
stride=None,
98+
padding=0,
99+
ceil_mode=False,
100+
count_include_pad=True
101+
):
102+
class TestModule(torch.nn.Module):
103+
def __init__(self):
104+
super().__init__()
105+
106+
def forward(self, x):
107+
return torch.nn.functional.avg_pool1d(
108+
x,
109+
kernel_size,
110+
stride=stride,
111+
padding=padding,
112+
ceil_mode=ceil_mode,
113+
count_include_pad=count_include_pad,
114+
)
115+
116+
inputs = [torch.randn(1, 3, 224)]
117+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d})
118+
119+
@parameterized.expand(
120+
[
121+
("kernal_size", 2),
122+
param("stride", 2, stride=()),
123+
]
124+
)
125+
def test_stride_none_avg_pool2d(
126+
self,
127+
test_name,
128+
kernel_size,
129+
stride=None,
130+
padding=0,
131+
ceil_mode=False,
132+
count_include_pad=True,
133+
divisor_override=None,
134+
):
135+
class TestModule(torch.nn.Module):
136+
def __init__(self):
137+
super().__init__()
138+
def forward(self, x):
139+
return torch.nn.functional.avg_pool2d(
140+
x,
141+
kernel_size,
142+
stride=stride,
143+
padding=padding,
144+
ceil_mode=ceil_mode,
145+
count_include_pad=count_include_pad,
146+
divisor_override=divisor_override,
147+
)
148+
149+
inputs = [torch.randn(1, 3, 224, 224)]
150+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d})
151+
87152
if __name__ == '__main__':
88153
run_tests()

test/converters/acc_op/test_maxpool.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,58 @@ def forward(self, x):
9292
TestModule(), input_specs, expected_ops={acc_ops.max_pool2d}
9393
)
9494

95+
@parameterized.expand(
96+
[
97+
("default", 1),
98+
param("stride", 2, stride=()),
99+
]
100+
)
101+
def test_stride_none_max_pool1d(self,
102+
test_name,
103+
kernel_size,
104+
stride=None,
105+
padding=0,
106+
dilation=1,
107+
ceil_mode=False,
108+
):
109+
class TestModule(torch.nn.Module):
110+
def __init__(self):
111+
super().__init__()
112+
113+
def forward(self, x):
114+
return torch.nn.functional.max_pool1d(
115+
x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode, dilation=dilation
116+
)
117+
118+
inputs = [torch.randn(1, 3, 224)]
119+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool1d}, test_explicit_batch_dim=False,)
120+
121+
122+
@parameterized.expand(
123+
[
124+
("default", 1),
125+
param("stride", 2, stride=()),
126+
]
127+
)
128+
def test_stride_none_max_pool2d(
129+
self,
130+
test_name,
131+
kernel_size,
132+
stride=None,
133+
padding=0,
134+
ceil_mode=False,
135+
):
136+
class TestModule(torch.nn.Module):
137+
def __init__(self):
138+
super().__init__()
139+
140+
def forward(self, x):
141+
return torch.nn.functional.max_pool2d(
142+
x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode
143+
)
144+
145+
inputs = [torch.randn(1, 3, 224, 224)]
146+
self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d})
147+
95148
if __name__ == '__main__':
96149
run_tests()

0 commit comments

Comments
 (0)