Skip to content

Commit e9cb030

Browse files
mcr229 authored and facebook-github-bot committed
Fix Concat Memory
Summary: There is a bug with some of the tensor manipulation functions when we operate them in NHWC. We serialize the permuted tensor shape, but other arguments which reference dimensions within the shape are not properly permuted as well. For example, suppose we are concatenating two tensors along dim 1. If we add a convolution before this node, our memory format tagging will tag this node to operate in NHWC. As a result, we will actually be concatenating along dim 3 because the tensor was permuted. Reviewed By: digantdesai Differential Revision: D47413372 fbshipit-source-id: 3a1cd9b0ca678d2d673fbb09ca2a35888d61f986
1 parent ab98f7a commit e9cb030

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

backends/xnnpack/operators/op_cat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
XNode,
1414
)
1515
from executorch.backends.xnnpack.utils.quant_utils import QuantParams
16+
from executorch.backends.xnnpack.utils.utils import PERM_NHWC_TO_NCHW
1617
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
1718

1819

@@ -49,6 +50,9 @@ def define_node(
4950
if len(node.args) > 1:
5051
axis = cast(int, node.args[1])
5152

53+
if "XNN_NHWC_NODE" in node.meta:
54+
axis = PERM_NHWC_TO_NCHW[axis]
55+
5256
if num_tensors_to_cat == 2:
5357
xnode = XNNConcatenate2(
5458
axis=axis,

backends/xnnpack/test/test_xnnpack.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,47 @@ def forward(self, x, y):
926926
(torch.randn(1, 2, 3), torch.randn(1, 2, 5)),
927927
)
928928

929+
def test_xnnpack_backend_concatenate_nhwc(self):
930+
class Concat(torch.nn.Module):
931+
def __init__(self):
932+
super().__init__()
933+
self.conv = torch.nn.Conv2d(
934+
in_channels=1,
935+
out_channels=3,
936+
kernel_size=(3, 3),
937+
padding=1,
938+
bias=False,
939+
)
940+
941+
def forward(self, x, y):
942+
x = self.conv(x)
943+
return torch.concatenate((y, x, y, x), 1)
944+
945+
self.lower_and_test_with_partitioner(
946+
Concat(), (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
947+
)
948+
949+
def test_xnnpack_backend_concatenate_nhwc2(self):
950+
class Concat(torch.nn.Module):
951+
def __init__(self):
952+
super().__init__()
953+
self.conv = torch.nn.Conv2d(
954+
in_channels=1,
955+
out_channels=3,
956+
kernel_size=(3, 3),
957+
padding=1,
958+
bias=False,
959+
)
960+
961+
def forward(self, x, y):
962+
x = self.conv(x)
963+
y = self.conv(y)
964+
return torch.concatenate((y, x, y, x), 3)
965+
966+
self.lower_and_test_with_partitioner(
967+
Concat(), (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
968+
)
969+
929970
def test_xnnpack_backend_slice_copy(self):
930971
class Slice(torch.nn.Module):
931972
def forward(self, x):

0 commit comments

Comments
 (0)