Skip to content

Commit 24bd94e

Browse files
tatwaichongfacebook-github-bot
authored andcommitted
Adjust pad value to meet the strict convolution shape calculation (#2059)
Summary: torch.nn.Conv2d does not require the result of `(input_size + 2 * pad - dilation * (kernel_size - 1) - 1) / stride` must be an integer, but tosa currently strictly require this property. Add a simple function to adjust the pad value to meet the requirement. Pull Request resolved: #2059 Reviewed By: mcr229 Differential Revision: D54214138 Pulled By: digantdesai fbshipit-source-id: 8ae0d3a0aabe47c61767c7ba6afa6d525054a566
1 parent f327e53 commit 24bd94e

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

backends/arm/operators/op_conv2d.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ class Conv2dVisitor(NodeVisitor):
2828
def __init__(self, *args):
2929
super().__init__(*args)
3030

31+
# torch.nn.Conv2d does not require the result of
32+
# `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
33+
# must be an integer, but tosa currently strictly require this property.
34+
# This function adjusts the pad value to meet the requirement.
35+
def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
36+
mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride
37+
38+
# No need to adjust
39+
if mod_remainder == 0:
40+
return pad
41+
42+
if mod_remainder > pad:
43+
raise RuntimeError(
44+
f"ignoring input element is not currently supported, got a large stride {stride}"
45+
)
46+
47+
return pad - mod_remainder
48+
3149
def define_node(
3250
self,
3351
node: torch.fx.Node,
@@ -52,6 +70,23 @@ def define_node(
5270
pad_attr = [val for val in pad.special for _ in (0, 1)]
5371
stride_attr = stride.special
5472
dilation_attr = dilation.special
73+
74+
# Adjust the pad value if needed to meet the strict convolution output shape calculation.
75+
pad_attr[1] = self.adjust_pad_if_needed(
76+
input.shape[2],
77+
weight.shape[2],
78+
stride_attr[0],
79+
pad_attr[1],
80+
dilation_attr[0],
81+
)
82+
pad_attr[3] = self.adjust_pad_if_needed(
83+
input.shape[3],
84+
weight.shape[3],
85+
stride_attr[1],
86+
pad_attr[3],
87+
dilation_attr[1],
88+
)
89+
5590
attr.ConvAttribute(
5691
pad=pad_attr,
5792
stride=stride_attr,

backends/arm/test/test_models.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -214,6 +214,30 @@ def forward(self, x):
214214
x = self.conv2d(x)
215215
return x
216216

217+
# A test where `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` is not an integer.
218+
@register_test
219+
class simple_conv2d_3x3_1x3x12x12_st2_pad1(torch.nn.Module):
220+
data = torch.ones(1, 3, 12, 12)
221+
inputs = {
222+
TosaProfile.BI: (data,),
223+
TosaProfile.MI: (data,),
224+
}
225+
226+
def __init__(self):
227+
super().__init__()
228+
self.conv2d = torch.nn.Conv2d(
229+
in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1
230+
)
231+
with torch.no_grad():
232+
self.conv2d.weight.copy_(
233+
rand_test_integers(low=1, high=4, size=(4, 3, 3, 3))
234+
)
235+
self.conv2d.bias.copy_(rand_test_integers(low=1, high=4, size=(4)))
236+
237+
def forward(self, x):
238+
x = self.conv2d(x)
239+
return x
240+
217241
@register_test
218242
class simple_conv2d_1x1_1x2x128x128_stride1(torch.nn.Module):
219243
data = torch.from_numpy(

0 commit comments

Comments
 (0)