Skip to content

Commit 9181e83

Browse files
authored
Make executorch stride sort size oblivious
Differential Revision: D68169867 Pull Request resolved: #7657
1 parent b450f9c commit 9181e83

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

exir/tensor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import math
1515
import typing
16-
from typing import Dict, List, Optional, Tuple, Union
16+
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
1717

1818
import executorch.exir.schema as schema
1919
import torch
@@ -70,8 +70,29 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
7070
for _, s in enumerate(stride):
7171
if s == 0:
7272
raise ValueError("0 in strides is not supported for ExecuTorch.")
73+
74+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
75+
76+
class K(NamedTuple):
77+
stride: int
78+
79+
def __lt__(self, other):
80+
return guard_size_oblivious(self.stride < other.stride)
81+
82+
def __gt__(self, other):
83+
return guard_size_oblivious(self.stride > other.stride)
84+
85+
def __le__(self, other):
86+
return guard_size_oblivious(self.stride <= other.stride)
87+
88+
def __ge__(self, other):
89+
return guard_size_oblivious(self.stride >= other.stride)
90+
91+
def __eq__(self, other):
92+
return guard_size_oblivious(self.stride == other.stride)
93+
7394
sorted_dims = [
74-
i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
95+
i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True)
7596
]
7697
return tuple(typing.cast(Tuple[bytes], sorted_dims))
7798

0 commit comments

Comments
 (0)