Skip to content

Commit c919152

Browse files
hsharma35facebook-github-bot
authored andcommitted
Fix conv1D and conv2D custom ops with channel last = True. (#6459)
Summary: Fix shape meta kernel for conv1D/conv2D when channel_last=True. Reviewed By: zonglinpeng Differential Revision: D62484686
1 parent 4f12131 commit c919152

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,16 @@ def quantized_conv_meta(
139139
assert len(in_size) < 6
140140

141141
# Compute the output tensor size
142+
# TODO(hardiksharma): Is this consistent with our implementation?
142143
output_size = (
143144
get_conv1d_output_size(
144-
in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
145+
in_size,
146+
out_channels,
147+
stride[0],
148+
padding[0],
149+
dilation[0],
150+
kernel_size[0],
151+
channel_last,
145152
)
146153
if len(in_size) == 3
147154
else get_conv2d_output_size(

backends/cadence/aot/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ def get_conv1d_output_size(
4343
padding: int,
4444
dilation: int,
4545
kernel_size: int,
46+
channel_last: bool,
4647
) -> torch.Size:
4748
assert len(in_size) == 3
4849
N, C, L = in_size
4950

5051
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
5152
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
5253

54+
if channel_last:
55+
return torch.Size((N, lout, out_channels))
5356
return torch.Size((in_size[0], out_channels, lout))
5457

5558

@@ -76,7 +79,8 @@ def get_conv2d_output_size(
7679
wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
7780
1
7881
] + 1
79-
82+
if channel_last:
83+
return torch.Size((N, hout, wout, out_channels))
8084
return torch.Size((in_size[0], out_channels, hout, wout))
8185

8286

0 commit comments

Comments
 (0)