Skip to content

Commit e93ad5f

Browse files
authored
Add transpose ops make convolutions channels-last.
Differential Revision: D62484686 Pull Request resolved: #6459
1 parent cbfdf78 commit e93ad5f

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ def quantized_conv_meta(
132132
out_shift: torch.Tensor,
133133
channel_last: bool = False,
134134
) -> torch.Tensor:
135-
out_channels, _in_channels, *kernel_size = weight.shape
135+
if channel_last:
136+
out_channels, *kernel_size, _ = weight.shape
137+
else:
138+
out_channels, _, *kernel_size = weight.shape
139+
136140
in_size = input.shape
137141
# Assert that the input tensor has at least 3 dimensions, and at most 6
138142
assert len(in_size) > 2
@@ -141,7 +145,13 @@ def quantized_conv_meta(
141145
# Compute the output tensor size
142146
output_size = (
143147
get_conv1d_output_size(
144-
in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
148+
in_size,
149+
out_channels,
150+
stride[1],
151+
padding[1],
152+
dilation[1],
153+
kernel_size[0],
154+
channel_last,
145155
)
146156
if len(in_size) == 3
147157
else get_conv2d_output_size(

backends/cadence/aot/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,20 @@ def get_conv1d_output_size(
4343
padding: int,
4444
dilation: int,
4545
kernel_size: int,
46+
channel_last: bool,
4647
) -> torch.Size:
4748
assert len(in_size) == 3
48-
N, C, L = in_size
49+
if channel_last:
50+
N, L, C = in_size
51+
else:
52+
N, C, L = in_size
4953

5054
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
5155
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
5256

53-
return torch.Size((in_size[0], out_channels, lout))
57+
if channel_last:
58+
return torch.Size((N, lout, out_channels))
59+
return torch.Size((N, out_channels, lout))
5460

5561

5662
# Get the output size of a 2D convolution given the input size and parameters
@@ -76,7 +82,8 @@ def get_conv2d_output_size(
7682
wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
7783
1
7884
] + 1
79-
85+
if channel_last:
86+
return torch.Size((N, hout, wout, out_channels))
8087
return torch.Size((in_size[0], out_channels, hout, wout))
8188

8289

0 commit comments

Comments
 (0)