|
11 | 11 |
|
12 | 12 | # pyre-unsafe
|
13 | 13 |
|
14 |
| -import copy |
15 | 14 | import math
|
16 | 15 | from operator import neg
|
17 | 16 | from typing import cast, Dict, Iterable, Sequence, Set, Tuple
|
|
36 | 35 | from executorch.backends.cadence.aot.utils import get_edge_overload_packet
|
37 | 36 | from executorch.exir.dialects._ops import ops as exir_ops
|
38 | 37 | from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
|
39 |
| -from executorch.exir.dim_order_utils import get_memory_format |
40 | 38 | from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
|
41 |
| -from executorch.exir.passes.dim_order_ops_registry import ( |
42 |
| - DimOrderOpsMap, |
43 |
| - MemoryFormatOpsMap, |
44 |
| -) |
45 | 39 | from torch._subclasses import FakeTensor
|
46 | 40 | from torch.fx.node import Argument
|
47 | 41 |
|
@@ -1805,72 +1799,6 @@ def call_operator(
|
1805 | 1799 | )
|
1806 | 1800 |
|
1807 | 1801 |
|
1808 |
| -@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1809 |
| -class ReplaceToDimOrderCopyWithToCopyPass(ExportPass): |
1810 |
| - """ |
1811 |
| - dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass. |
1812 |
| - If the dim order is sequential, we don't need the extra work with strides and |
1813 |
| - can just use to_copy. |
1814 |
| - """ |
1815 |
| - |
1816 |
| - def call_operator( |
1817 |
| - self, |
1818 |
| - op, |
1819 |
| - args: Tuple[Argument, ...], |
1820 |
| - kwargs: Dict[str, Argument], |
1821 |
| - meta: NodeMetadata, |
1822 |
| - ) -> ProxyValue: |
1823 |
| - if op not in DimOrderOpsMap: |
1824 |
| - return super().call_operator(op, args, kwargs, meta) |
1825 |
| - |
1826 |
| - # new kwargs with dim_order, and no memory_format for the new op |
1827 |
| - nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable |
1828 |
| - |
1829 |
| - ndim = None |
1830 |
| - |
1831 |
| - # can always get the shape, assuming rank is specialized |
1832 |
| - |
1833 |
| - # pyre-ignore[16]: `None` has no attribute `to_tensor` |
1834 |
| - if isinstance(args[0], ProxyValue) and args[0].is_tensor(): |
1835 |
| - # pyre-ignore[16]: `None` has no attribute `to_tensor` |
1836 |
| - ndim = args[0].to_tensor().dim() |
1837 |
| - elif isinstance(args[0], torch.Tensor): |
1838 |
| - # pyre-ignore[16]: `None` has no attribute `dim` |
1839 |
| - ndim = args[0].dim() |
1840 |
| - elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): |
1841 |
| - # pyre-ignore[6]: Incompatible parameter type |
1842 |
| - ndim = len(args[0]) |
1843 |
| - else: |
1844 |
| - assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}" |
1845 |
| - |
1846 |
| - # get the "to" memory format for the EdgeOp |
1847 |
| - contiguous_dim_order = list(range(ndim)) |
1848 |
| - dim_order = nkwargs.pop("dim_order", None) |
1849 |
| - |
1850 |
| - # Cadence only supports contiguous memory format |
1851 |
| - assert ( |
1852 |
| - dim_order is None |
1853 |
| - # pyre-ignore[6]: Incompatible parameter type |
1854 |
| - or len(dim_order) == 0 |
1855 |
| - or dim_order == contiguous_dim_order |
1856 |
| - ), "Expected dim order in congituous or prevserve memory format, but got {}".format( |
1857 |
| - dim_order |
1858 |
| - ) |
1859 |
| - |
1860 |
| - # bring back memory format |
1861 |
| - # pyre-ignore[6]: Incompatible parameter type |
1862 |
| - nkwargs["memory_format"] = get_memory_format(dim_order) |
1863 |
| - |
1864 |
| - memory_format_op = MemoryFormatOpsMap[op] |
1865 |
| - |
1866 |
| - return super().call_operator( |
1867 |
| - memory_format_op, |
1868 |
| - args, |
1869 |
| - nkwargs, |
1870 |
| - meta, |
1871 |
| - ) |
1872 |
| - |
1873 |
| - |
1874 | 1802 | @register_cadence_pass(CadencePassAttribute(opt_level=0))
|
1875 | 1803 | class ReplaceFullLikeWithFullPass(ExportPass):
|
1876 | 1804 | """
|
@@ -2180,5 +2108,4 @@ class CadenceReplaceOpsInGraph:
|
2180 | 2108 | ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
|
2181 | 2109 | ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
|
2182 | 2110 | ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
|
2183 |
| - ReplaceToDimOrderCopyWithToCopyPass, |
2184 | 2111 | ]
|
0 commit comments