
Commit 8487a23

Gasoonjia authored and facebook-github-bot committed
directly using dim order for memory format comparison. (#2170)
Summary: bypass-github-export-checks

Reviewed By: digantdesai

Differential Revision: D54341919
1 parent c0b5f5f commit 8487a23

File tree

2 files changed: +27 lines, -13 lines


exir/dim_order_utils.py

Lines changed: 18 additions & 0 deletions
@@ -61,3 +61,21 @@ def get_dim_order(
     raise AssertionError(
         f"Failed to generate dim_order for a given memory format: {memory_format}"
     )
+
+
+def is_channel_last_dim_order(tensor: torch.Tensor) -> bool:
+    """
+    Check if a tensor has channels last dim order
+    """
+    if tensor.dim() != 4:
+        # Only support 4D tensors for channels last memory format.
+        return False
+
+    return tensor.dim_order() == tuple(_get_channels_last_dim_order(tensor.dim()))
+
+
+def is_contiguous_dim_order(tensor: torch.Tensor) -> bool:
+    """
+    Check if a tensor has contiguous dim order
+    """
+    return tensor.dim_order() == tuple(_get_contiguous_dim_order(tensor.dim()))
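
For context, here is a minimal sketch (not part of this diff) of the dim order comparison these helpers perform, assuming a PyTorch build where torch.Tensor.dim_order() is available:

    import torch

    x = torch.randn(2, 3, 4, 5)                  # 4D contiguous tensor (NCHW)
    y = x.to(memory_format=torch.channels_last)  # same values, NHWC layout

    print(x.dim_order())  # (0, 1, 2, 3) -> contiguous dim order
    print(y.dim_order())  # (0, 2, 3, 1) -> channels-last dim order

    # For this shape the old stride heuristic (stride(1) == 1) happens to agree,
    # but dim_order() reports the layout directly instead of inferring it from strides.
    print(x.stride(1) == 1, y.stride(1) == 1)  # False True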

exir/tests/test_memory_format_ops_pass.py

Lines changed: 9 additions & 13 deletions
@@ -10,6 +10,11 @@
 
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
+
+from executorch.exir.dim_order_utils import (
+    is_channel_last_dim_order,
+    is_contiguous_dim_order,
+)
 from torch.export import export
 from torch.testing import FileCheck
 

@@ -22,15 +27,6 @@ class MemoryFormatTestSet:
 
 
 class TestMemoryFormatOpsPass(unittest.TestCase):
-    def is_channel_last(self, x: torch.Tensor):
-        # This is a heuristic to determine whether the input tensor is NHWC (channels last),
-        # because we do not have a good way to infer the dimension order or the memory format
-        # of the input tensor. Note that this function is specific to contiguous tensors
-        # whose dim(1) is the channel dimension; other tensors may not work well due to
-        # different channel configurations and memory arrangements.
-
-        return x.stride(1) == 1
-
     def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
         aten_op_str = "torch.ops.aten._to_copy.default"
         edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

@@ -60,13 +56,13 @@ def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
         actual = epm.exported_program().module()(*test_set.sample_input)
         self.assertTrue(torch.allclose(actual, expected))
         self.assertEqual(
-            self.is_channel_last(actual),
-            self.is_channel_last(expected),
+            is_channel_last_dim_order(actual),
+            is_channel_last_dim_order(expected),
         )
         if test_set.target_memory_format == torch.channels_last:
-            self.assertTrue(self.is_channel_last(actual))
+            self.assertTrue(is_channel_last_dim_order(actual))
         elif test_set.target_memory_format == torch.contiguous_format:
-            self.assertFalse(self.is_channel_last(actual))
+            self.assertTrue(is_contiguous_dim_order(actual))
         else:
             raise RuntimeError("Unknown memory format")
 
