Skip to content

Commit 49f0df0

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
make utils support empty dim order (#2142)
Summary: This update makes util function support empty dim order, to make the empty dim order behave the same as empty memory format (preserve_format). bypass-github-export-checks Differential Revision: D54236386
1 parent 9d6bf72 commit 49f0df0

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

exir/dim_order_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import List
7+
from typing import List, Optional
88

99
import torch
1010

@@ -27,11 +27,13 @@ def _get_channels_last_dim_order(ndim: int) -> List[int]:
2727
raise AssertionError(f"Unsupported rank: {ndim}")
2828

2929

30-
def get_memory_format(dim_order: List[int]) -> torch.memory_format:
30+
def get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
3131
"""
3232
Given a dim_order try to map it to torch.memory_format
3333
"""
34-
if dim_order == _get_contiguous_dim_order(len(dim_order)):
34+
if dim_order is None:
35+
return torch.preserve_format
36+
elif dim_order == _get_contiguous_dim_order(len(dim_order)):
3537
return torch.contiguous_format
3638
elif len(dim_order) == 4 and dim_order == _get_channels_last_dim_order(
3739
len(dim_order)
@@ -43,11 +45,15 @@ def get_memory_format(dim_order: List[int]) -> torch.memory_format:
4345
)
4446

4547

46-
def get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]:
48+
def get_dim_order(
49+
memory_format: Optional[torch.memory_format], ndim: int
50+
) -> Optional[List[int]]:
4751
"""
4852
Given a memory_format and a tensor rank, generate a dim_order
4953
"""
50-
if memory_format == torch.contiguous_format:
54+
if memory_format in [None, torch.preserve_format]:
55+
return None
56+
elif memory_format == torch.contiguous_format:
5157
return _get_contiguous_dim_order(ndim)
5258
elif memory_format == torch.channels_last:
5359
return _get_channels_last_dim_order(ndim)

0 commit comments

Comments
 (0)