|
16 | 16 | import itertools
|
17 | 17 | import logging
|
18 | 18 | from dataclasses import dataclass, field
|
19 |
| -from typing import Callable, cast, Dict, List, Optional, Sequence |
| 19 | +from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union |
20 | 20 |
|
21 | 21 | import torch
|
22 | 22 | import torch.fx
|
@@ -698,16 +698,45 @@ def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
|
698 | 698 | sg.is_valid = False
|
699 | 699 |
|
700 | 700 | def is_starting_permute(self, node: torch.fx.Node) -> bool:
|
701 |
| - return ( |
702 |
| - node.target == exir_ops.edge.aten.permute_copy.default |
703 |
| - and cast(list[int], node.args[1]) == self.to_NCHW |
704 |
| - ) |
| 701 | + return self.is_boundary_permute(node, self.to_NCHW) |
705 | 702 |
|
706 | 703 | def is_ending_permute(self, node: torch.fx.Node) -> bool:
|
707 |
| - return ( |
708 |
| - node.target == exir_ops.edge.aten.permute_copy.default |
709 |
| - and cast(list[int], node.args[1]) == self.to_NHWC |
710 |
| - ) |
| 704 | + return self.is_boundary_permute(node, self.to_NHWC) |
| 705 | + |
| 706 | + @staticmethod |
| 707 | + def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool: |
| 708 | + permute_dims = list(permute_dims) |
| 709 | + if node.target == exir_ops.edge.aten.permute_copy.default: |
| 710 | + return cast(list[int], node.args[1]) == permute_dims |
| 711 | + elif node.target == exir_ops.edge.aten.view_copy.default: |
| 712 | + # If there's a view node, check if it's swapping two dimensions and |
| 713 | + # not splitting any others from the input shape. |
| 714 | + inp = node.args[0] |
| 715 | + if not isinstance(inp, torch.fx.Node): |
| 716 | + return False |
| 717 | + input_shape = inp.meta["val"].shape |
| 718 | + output_shape = node.args[1] |
| 719 | + assert isinstance(output_shape, (tuple, list)) |
| 720 | + # If the shapes are equal in length, no dimension is being split or |
| 721 | + # grouped. Then check if a permute of the input shape results in the output shape. |
| 722 | + return ( |
| 723 | + len(input_shape) == len(output_shape) |
| 724 | + and len(input_shape) == len(permute_dims) |
| 725 | + and RemovePermutesAroundElementwiseOps.permute_shape( |
| 726 | + input_shape, permute_dims |
| 727 | + ) |
| 728 | + == output_shape |
| 729 | + ) |
| 730 | + else: |
| 731 | + return False |
| 732 | + |
| 733 | + @staticmethod |
| 734 | + def permute_shape( |
| 735 | + shape: Union[List[int], torch.Size], permute_dims: Iterable[int] |
| 736 | + ) -> List[int]: |
| 737 | + permute_dims = list(permute_dims) |
| 738 | + assert len(shape) == len(permute_dims) |
| 739 | + return [shape[p] for p in permute_dims] |
711 | 740 |
|
712 | 741 |
|
713 | 742 | # The following class consolidates functions to remove ops that are redundant
|
|
0 commit comments