10
10
11
11
import torch
12
12
from executorch .exir import EdgeCompileConfig , to_edge
13
+
14
+ from executorch .exir .dim_order_utils import (
15
+ is_channel_last_dim_order ,
16
+ is_contiguous_dim_order ,
17
+ )
13
18
from torch .export import export
14
19
from torch .testing import FileCheck
15
20
@@ -22,15 +27,6 @@ class MemoryFormatTestSet:
22
27
23
28
24
29
class TestMemoryFormatOpsPass (unittest .TestCase ):
25
- def is_channel_last (self , x : torch .Tensor ):
26
- # This is a heuristic to determine if the input tensor is in NHWC (channel last)
27
- # due to we do not have a good way to infer the dimension order or the memory format
28
- # of the input tensor. Please not this function is specific for contiguous tensors
29
- # whose dim(1) is channel one only, other types of tensors may not work well
30
- # due to different channel configuration and memory arrangement.
31
-
32
- return x .stride (1 ) == 1
33
-
34
30
def memory_format_test_runner (self , test_set : MemoryFormatTestSet ):
35
31
aten_op_str = "torch.ops.aten._to_copy.default"
36
32
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
@@ -60,13 +56,13 @@ def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
60
56
actual = epm .exported_program ().module ()(* test_set .sample_input )
61
57
self .assertTrue (torch .allclose (actual , expected ))
62
58
self .assertEqual (
63
- self . is_channel_last (actual ),
64
- self . is_channel_last (expected ),
59
+ is_channel_last_dim_order (actual ),
60
+ is_channel_last_dim_order (expected ),
65
61
)
66
62
if test_set .target_memory_format == torch .channels_last :
67
- self .assertTrue (self . is_channel_last (actual ))
63
+ self .assertTrue (is_channel_last_dim_order (actual ))
68
64
elif test_set .target_memory_format == torch .contiguous_format :
69
- self .assertFalse ( self . is_channel_last (actual ))
65
+ self .assertTrue ( is_contiguous_dim_order (actual ))
70
66
else :
71
67
raise RuntimeError ("Unknown memory format" )
72
68
0 commit comments