
Commit a3f792a

pytorchbot and SS-JIA authored
[ET-VK] Add pass to remove copy ops (#7326)
## Context

This diff prepares Vulkan to handle dim order operators; for more context, see #4873.

Since Vulkan has its own internal representation of memory layout, these ops are handled by simply removing explicit memory layout transition operators from the graph and letting the memory metadata tagging pass insert the necessary memory layout transitions. A new pass is added to remove such operators, largely based on QNN's `RemoveRedundancy` pass.

Differential Revision: [D67180898](https://our.internmc.facebook.com/intern/diff/D67180898/)

ghstack-source-id: 258092214
Pull Request resolved: #7325

Co-authored-by: Stephen Jia <[email protected]>
1 parent 47c2f2e commit a3f792a
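
To make the mechanism concrete, here is a toy illustration (not part of the commit) of the kind of node the new pass eliminates: an explicit `clone` is redundant for Vulkan because the backend decides memory layout on its own.

```python
# Toy illustration (not from this commit): a graph containing torch.clone,
# one of the ops RemoveRedundantOpsTransform treats as redundant.
import torch

class CloneThenAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.clone(x)  # explicit copy; Vulkan manages layout internally
        return y + 1

gm = torch.fx.symbolic_trace(CloneThenAdd())
print(gm.graph)
# Before the pass: x -> clone -> add -> output.
# After the pass: uses of the clone node are rewired to x and the clone is
# eliminated as dead code, leaving x -> add -> output.
```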

File tree: 5 files changed (+91, -2 lines)


backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions
```diff
@@ -43,6 +43,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "remove_redundant_ops",
+    srcs = ["remove_redundant_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "tag_memory_meta_pass",
     srcs = ["tag_memory_meta_pass.py"],
@@ -71,6 +84,7 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":remove_redundant_ops",
         ":tag_memory_meta_pass"
     ]
 )
```

backends/vulkan/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -5,11 +5,15 @@
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.remove_redundant_ops import (
+    RemoveRedundantOpsTransform,
+)
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",
 ]
```
backends/vulkan/_passes/remove_redundant_ops.py

Lines changed: 69 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]


class RemoveRedundantOpsTransform(ExportPass):
    """
    Trim certain operators to reduce unnecessary overhead.
    """

    redundant_ops: Set[OpType] = {
        torch.clone,
        torch.ops.aten.clone.default,
        exir_ops.edge.aten.clone.default,
        torch.ops.aten.alias.default,
        exir_ops.edge.aten.alias.default,
        exir_ops.edge.aten.lift_fresh_copy.default,
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    }

    def __init__(self) -> None:
        super(RemoveRedundantOpsTransform, self).__init__()

    def _should_remove(self, node: torch.fx.Node) -> bool:
        if node.target in self.redundant_ops:
            return True

        # Only remove _to_copy if the dtype does not change; in that case the
        # node is purely a memory format change, which the backend handles
        # internally.
        if (
            node.target == exir_ops.edge.aten._to_copy.default
            or node.target == torch.ops.aten._to_copy.default
        ):
            src_dtype = node.meta["val"].dtype
            # pyre-ignore
            dst_dtype = node.args[0].meta["val"].dtype
            return src_dtype == dst_dtype

        return False

    def _remove(self, graph_module: torch.fx.GraphModule) -> None:
        for node in graph_module.graph.nodes:
            if not self._should_remove(node):
                continue

            with graph_module.graph.inserting_after(node):
                node.replace_all_uses_with(node.args[0])

        graph_module.graph.eliminate_dead_code()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self._remove(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
```
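
As a minimal usage sketch (the toy module below is hypothetical, not part of the commit), the pass can be applied directly to a traced `GraphModule`:

```python
# Minimal usage sketch: apply RemoveRedundantOpsTransform to a traced module
# and check that the clone node is gone. The CloneThenAdd module is a toy.
import torch
from executorch.backends.vulkan._passes import RemoveRedundantOpsTransform

class CloneThenAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clone(x) + 1

gm = torch.fx.symbolic_trace(CloneThenAdd())
result = RemoveRedundantOpsTransform()(gm)  # PassBase.__call__ invokes call()
gm = result.graph_module

# The add now consumes the placeholder directly; no clone node remains.
assert all(node.target is not torch.clone for node in gm.graph.nodes)
```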

backends/vulkan/op_registry.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -228,6 +228,8 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        # dim order copy operator will be removed; memory layout is handled internally
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
```

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -17,11 +17,11 @@
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
+    RemoveRedundantOpsTransform,
     TagMemoryMetaPass,
 )
 
@@ -143,7 +143,7 @@ def preprocess(  # noqa: C901
     program = apply_passes(
         program,
         [
-            RemoveCloneOpsTransform(),
+            RemoveRedundantOpsTransform(),
             AddmmToLinearTransform(),
             FuseDequantLinearPass(),
             FuseViewCopyTransform(),
```
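
For orientation, here is a minimal sketch of how a pass list like the one above is threaded through in sequence; the loop below is an illustration in the spirit of `apply_passes`, not the actual helper.

```python
# Sketch (assumed helper, for illustration only): run ExportPass-style passes
# in order, threading the possibly-new GraphModule through each PassResult.
from typing import List

import torch
from executorch.exir.pass_base import ExportPass

def run_passes_sketch(
    gm: torch.fx.GraphModule, passes: List[ExportPass]
) -> torch.fx.GraphModule:
    for p in passes:
        result = p(gm)  # PassBase.__call__ dispatches to p.call(gm)
        if result is not None:
            gm = result.graph_module
    return gm
```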
