
Commit f888bdf

Revert "Milestone2.1: Partition to_dim_order_copy op in XNN delegate" (#12090)
Reverts #11286. Unit tests have been failing consistently since the offending PR landed: https://hud.pytorch.org/hud/pytorch/executorch/main/1?per_page=50&name_filter=unittest
1 parent 2f7440d commit f888bdf

File tree

5 files changed: +0 additions, -139 deletions


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 0 additions & 5 deletions
@@ -395,11 +395,6 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
-            elif node.target == exir_ops.edge.aten._to_copy.default:
-                if node.kwargs["memory_format"] == torch.channels_last:
-                    self.mark_as_nhwc_node(node)
-                else:
-                    self.mark_as_nchw_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
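
For context, the branch removed above keyed off edge-dialect aten._to_copy nodes. A minimal sketch (plain torch.export, not part of this commit or of any ExecuTorch lowering) of how an eager memory-format cast surfaces as the _to_copy node the pass used to tag:

import torch

class DimSwap(torch.nn.Module):
    def forward(self, x):
        # An eager .to(memory_format=...) call materializes a copy.
        return x.to(memory_format=torch.channels_last)

# After export and functionalization, the cast should appear as a
# torch.ops.aten._to_copy.default node carrying the memory_format kwarg
# (assumption: this matches the decomposition behavior of the installed
# torch version).
ep = torch.export.export(DimSwap(), (torch.randn(1, 3, 4, 4),))
print(ep.run_decompositions({}).graph)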

backends/xnnpack/partition/config/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -50,7 +50,6 @@
     SquareRootConfig,
     SubConfig,
     TanhConfig,
-    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -103,8 +102,6 @@
     ReciprocalSquareRootConfig,
     ReLUConfig,
     TanhConfig,
-    ToDimOrderCopyConfig,
-    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
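
These two hunks only edit import and registration lists, and that is what actually disables partitioning of the op: a config's presence in the list is the whole opt-in. A generic sketch of that registry pattern (the classes here are illustrative stand-ins, not ExecuTorch's real API):

from typing import List, Type

class NodePartitionerConfig:
    target_name: str = ""

class TanhConfig(NodePartitionerConfig):
    target_name = "tanh.default"

# The partitioner iterates a list like this; with the ToDimOrderCopyConfig
# entry reverted out, _to_dim_order_copy nodes are simply never claimed.
ALL_PARTITIONER_CONFIGS: List[Type[NodePartitionerConfig]] = [
    TanhConfig,
]

def is_partitionable(op_name: str) -> bool:
    return any(cfg.target_name == op_name for cfg in ALL_PARTITIONER_CONFIGS)

print(is_partitionable("_to_dim_order_copy.default"))  # False after the revert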

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 0 additions & 29 deletions
@@ -397,35 +397,6 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
-class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
-    target_name = "_to_dim_order_copy.default"
-
-    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
-        """
-        Only support dim order conversion partitioning, not DType conversions
-        """
-        if not self.check_common_constraints(node, ep):
-            return False
-
-        # Get input node and compare dtypes
-        input_node = get_input_node(node, 0)
-        input_dtype = input_node.meta["val"].dtype
-        output_dtype = node.meta["val"].dtype
-
-        # Return False if doing dtype conversion
-        if input_dtype != output_dtype:
-            why(
-                node,
-                reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
-            )
-            return False
-
-        return True
-
-    def supported_precision_types(self) -> List[ConfigPrecisionType]:
-        return [ConfigPrecisionType.FP32]
-
-
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
 
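
The reverted check_constraints only accepted copies that leave the dtype alone. A standalone sketch (plain PyTorch, not the delegate code) of the two cases it separated:

import torch

x = torch.randn(1, 3, 4, 4)

# Dim-order-only copy: dtype unchanged -> the reverted config partitioned it.
dim_order_only = x.to(memory_format=torch.channels_last)
assert dim_order_only.dtype == x.dtype

# DType conversion: the reverted config returned False (logging a `why` reason).
dtype_change = x.to(dtype=torch.float16)
assert dtype_change.dtype != x.dtype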

backends/xnnpack/test/ops/test_to_copy.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 0 additions & 17 deletions
@@ -173,23 +173,6 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
-    class LinearConvDimSwap(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.conv1 = torch.nn.Conv2d(3, 3, 3)
-            self.linear1 = torch.nn.Linear(4, 3)
-
-        def forward(self, x):
-            y = self.linear1(x)
-            y = y.to(memory_format=torch.channels_last)
-            y = y.to(memory_format=torch.contiguous_format)
-            return self.conv1(y)
-
-    LinearConvDimSwapModule = LinearConvDimSwap()
-
-    def test_conv_linear_dim_order_swap_partitioner(self):
-        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
-
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
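
For reference, the deleted module still runs in eager mode; a quick sanity check (shapes inferred from the deleted test's input, annotations are mine) looks like:

import torch

class LinearConvDimSwap(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 3)

    def forward(self, x):
        y = self.linear1(x)                              # (1, 3, 6, 3)
        y = y.to(memory_format=torch.channels_last)      # NHWC strides
        y = y.to(memory_format=torch.contiguous_format)  # back to NCHW
        return self.conv1(y)                             # (1, 3, 4, 1)

out = LinearConvDimSwap()(torch.randn(1, 3, 6, 4))
print(out.shape)  # torch.Size([1, 3, 4, 1])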
