Commit ab98f7a

mcr229 authored and facebook-github-bot committed
Fix Slice Memory
Summary: There is a bug in some of the tensor manipulation functions when they operate in NHWC. We serialize the permuted tensor shape, but other arguments that reference dimensions within that shape are not permuted to match. For example, suppose we slice along dim 1 and a convolution precedes this node: our memory-format tagging will tag the slice to operate in NHWC, and because the tensor was permuted, the dimension we actually need to slice is dim 3, so the unmapped index lands on the wrong axis.

Reviewed By: digantdesai

Differential Revision: D47413374

fbshipit-source-id: d6bd661a30d084a48756a9f06fc5d8d16e12aebd
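To make the failure concrete, here is a minimal standalone sketch in plain PyTorch. It illustrates the layout mismatch, not the ExecuTorch serialization path; the permutation constants are assumed values that follow the usual NCHW/NHWC convention (ExecuTorch keeps equivalents in executorch.backends.xnnpack.utils.utils):

import torch

# Permutations between NCHW and NHWC dimension orders (assumed values).
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # N,C,H,W -> N,H,W,C
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]  # inverse; also maps an NCHW dim index to its NHWC position

x = torch.randn(1, 4, 8, 8)            # NCHW
x_nhwc = x.permute(PERM_NCHW_TO_NHWC)  # the layout the backend actually sees

dim = 1  # we intend to slice channels, i.e. dim 1 in NCHW

buggy = x_nhwc.narrow(dim, 0, 2)                     # slices H: the wrong axis
fixed = x_nhwc.narrow(PERM_NHWC_TO_NCHW[dim], 0, 2)  # slices C, as intended

reference = x.narrow(dim, 0, 2).permute(PERM_NCHW_TO_NHWC)
assert torch.equal(fixed, reference)
assert buggy.shape != reference.shape  # (1, 2, 8, 4) vs. (1, 8, 8, 2)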
1 parent 1e4658e commit ab98f7a

File tree

2 files changed (+30, -2 lines)


backends/xnnpack/operators/op_slice_copy.py

Lines changed: 12 additions & 2 deletions
@@ -10,7 +10,12 @@
     XNNStaticSlice,
     XNode,
 )
-from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+    PERM_NHWC_TO_NCHW,
+)
 
 
 @register_node_visitor
@@ -52,8 +57,13 @@ def define_node(
         )
         output_tensor_val = node.meta["val"]
         output_shape = list(output_tensor_val.shape)
-
         dim_of_slice = cast(int, node.args[1])
+
+        if "XNN_NHWC_NODE" in node.meta:
+            input_shape = [input_shape[i] for i in PERM_NCHW_TO_NHWC]
+            output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
+            dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]
+
         slice_begin_index = cast(int, node.args[2])
         if slice_begin_index < 0:
             slice_begin_index = input_shape[dim_of_slice] + slice_begin_index
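A worked instance of the remap in the hunk above; the concrete shape and dim are illustrative, not taken from the commit:

PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]

input_shape = [1, 3, 5, 5]  # NCHW shape recorded in the graph
dim_of_slice = 2            # slice along H

# what the XNN_NHWC_NODE branch does:
input_shape = [input_shape[i] for i in PERM_NCHW_TO_NHWC]  # -> [1, 5, 5, 3] (NHWC)
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]             # -> 1, H's position in NHWC

assert input_shape[dim_of_slice] == 5  # still indexes the H extent

Note that negative begin indices (handled just below the hunk) are resolved against the already-permuted input_shape, so they stay consistent with the remapped dim.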

backends/xnnpack/test/test_xnnpack.py

Lines changed: 18 additions & 0 deletions
@@ -955,3 +955,21 @@ def forward(self, x):
         self.lower_module_and_test_output(
             Slice(), (torch.randn(5, 5, 5),), use_partitioner=True
         )
+
+    def test_xnnpack_backend_slice_copy_memory_format(self):
+        class ConvSlice(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=1,
+                    out_channels=3,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                y = self.conv(x)
+                return y[:, :, 2:3, -2:]
+
+        self.lower_and_test_with_partitioner(ConvSlice(), (torch.randn(1, 1, 3, 3),))
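The conv forces the memory-format pass to tag the following slices as NHWC, and y[:, :, 2:3, -2:] decomposes into slice ops on dims 2 and 3, exactly the indices the fix must remap (to 1 and 2 in NHWC). For reference, an eager-mode sketch of the expected result (reusing ConvSlice outside the test method is hypothetical; the test itself drives the module through the partitioner):

m = ConvSlice()
x = torch.randn(1, 1, 3, 3)
ref = m(x)                        # conv output is (1, 3, 3, 3) with padding=1
assert ref.shape == (1, 3, 1, 2)  # H sliced 2:3 -> 1; W sliced -2: -> 2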
