Skip to content

Commit 052f624

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Fix op_to_copy (#416)
Summary: Pull Request resolved: #416 `memory_format` is optional. Make contiguous as default. Only convert to nhwc when explicitly needed. Reviewed By: digantdesai Differential Revision: D49443773 fbshipit-source-id: 508d941557dc99b64462fc4749b9015af2d0a6ec
1 parent 0d58374 commit 052f624

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

backends/xnnpack/operators/op_to_copy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict
88

99
import torch
10+
from executorch.backends.transforms import get_shape
1011
from executorch.backends.xnnpack.operators.node_visitor import (
1112
get_tensor_value,
1213
NodeVisitor,
@@ -41,7 +42,7 @@ def define_node(
4142
vals_to_ids: Dict[torch.fx.Node, int],
4243
debug_handle: int,
4344
) -> None:
44-
memory_format_target = node.kwargs["memory_format"]
45+
memory_format_target = node.kwargs.get("memory_format", torch.contiguous_format)
4546
to_channels_last = bool(memory_format_target == torch.channels_last)
4647
to_contiguous = bool(memory_format_target == torch.contiguous_format)
4748
check_or_raise(
@@ -61,7 +62,7 @@ def define_node(
6162
vals_to_ids,
6263
quant_params=input_quant_params,
6364
convert_to_nhwc=(
64-
not to_channels_last
65+
(not to_channels_last) and len(get_shape(input_node)) == 4
6566
), # input is contiguous if converting out of channels last
6667
)
6768

0 commit comments

Comments
 (0)