@@ -13,7 +13,7 @@
 
 import math
 import typing
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import executorch.exir.schema as schema
 import torch
@@ -70,8 +70,29 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
     for _, s in enumerate(stride):
         if s == 0:
             raise ValueError("0 in strides is not supported for ExecuTorch.")
+
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    class K(NamedTuple):
+        stride: int
+
+        def __lt__(self, other):
+            return guard_size_oblivious(self.stride < other.stride)
+
+        def __gt__(self, other):
+            return guard_size_oblivious(self.stride > other.stride)
+
+        def __le__(self, other):
+            return guard_size_oblivious(self.stride <= other.stride)
+
+        def __ge__(self, other):
+            return guard_size_oblivious(self.stride >= other.stride)
+
+        def __eq__(self, other):
+            return guard_size_oblivious(self.stride == other.stride)
+
     sorted_dims = [
-        i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
+        i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True)
     ]
     return tuple(typing.cast(Tuple[bytes], sorted_dims))
 
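The change routes every stride comparison made by the sort through `guard_size_oblivious`, so `dim_order_from_stride` can also handle symbolic strides (SymInts from dynamic shapes), where a bare `<` on unbacked symbols can force a guard at trace time. The ordering rule itself is unchanged: dimensions are listed from largest stride to smallest. Below is a minimal sketch of that rule on concrete strides; the helper name `dim_order_from_stride_concrete` is hypothetical and mirrors only the sort, not the zero-stride check or the symbolic guard:

    import torch

    def dim_order_from_stride_concrete(stride):
        # Dimensions ordered from largest stride to smallest; Python's sort
        # is stable (also with reverse=True), so equal strides keep their
        # original dimension order.
        return tuple(
            i for i, _ in sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
        )

    contiguous = torch.empty(2, 3, 4, 5)                     # strides (60, 20, 5, 1)
    nhwc = contiguous.to(memory_format=torch.channels_last)  # strides (60, 1, 15, 3)

    print(dim_order_from_stride_concrete(contiguous.stride()))  # (0, 1, 2, 3)
    print(dim_order_from_stride_concrete(nhwc.stride()))        # (0, 2, 3, 1)

So a contiguous tensor maps to the identity dim order, while a channels-last tensor maps to the NHWC order (0, 2, 3, 1); the patched key class `K` produces the same ordering when the strides are symbolic.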