
Commit 4a653db

[ET-VK] Add pass to remove copy ops

Differential Revision: D67180898
Pull Request resolved: #7325

1 parent 8460d42 · commit 4a653db

File tree: 5 files changed, +91 −2 lines


backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -43,6 +43,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "remove_redundant_ops",
+    srcs = ["remove_redundant_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "tag_memory_meta_pass",
     srcs = ["tag_memory_meta_pass.py"],

@@ -71,6 +84,7 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":remove_redundant_ops",
        ":tag_memory_meta_pass"
    ]
)

backends/vulkan/_passes/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -5,11 +5,15 @@
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.remove_redundant_ops import (
+    RemoveRedundantOpsTransform,
+)
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",
 ]
backends/vulkan/_passes/remove_redundant_ops.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveRedundantOpsTransform(ExportPass):
+    """
+    Trim certain operators to reduce unnecessary overhead.
+    """
+
+    redundant_ops: Set[OpType] = {
+        torch.clone,
+        torch.ops.aten.clone.default,
+        exir_ops.edge.aten.clone.default,
+        torch.ops.aten.alias.default,
+        exir_ops.edge.aten.alias.default,
+        exir_ops.edge.aten.lift_fresh_copy.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    }
+
+    def __init__(self) -> None:
+        super(RemoveRedundantOpsTransform, self).__init__()
+
+    def _should_remove(self, node: torch.fx.Node) -> bool:
+        if node.target in self.redundant_ops:
+            return True
+
+        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
+        # will be handled internally by the backend.
+        if (
+            node.target == exir_ops.edge.aten._to_copy.default
+            or node.target == torch.ops.aten._to_copy.default
+        ):
+            src_dtype = node.meta["val"].dtype
+            # pyre-ignore
+            dst_dtype = node.args[0].meta["val"].dtype
+            return src_dtype == dst_dtype
+
+        return False
+
+    def _remove(self, graph_module: torch.fx.GraphModule) -> None:
+        for node in graph_module.graph.nodes:
+            if not self._should_remove(node):
+                continue
+
+            with graph_module.graph.inserting_after(node):
+                node.replace_all_uses_with(node.args[0])
+
+        graph_module.graph.eliminate_dead_code()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self._remove(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
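
Usage note (not part of the commit): the sketch below shows how the new transform behaves when applied directly to an exported graph module. The toy CloneModel module and the standalone invocation outside the Vulkan preprocess pipeline are assumptions for illustration only.

import torch

from executorch.backends.vulkan._passes import RemoveRedundantOpsTransform


class CloneModel(torch.nn.Module):  # hypothetical toy model, not in the diff
    def forward(self, x):
        y = torch.clone(x)  # exports as aten.clone.default, which the pass strips
        return y + 1


ep = torch.export.export(CloneModel(), (torch.randn(4),))
gm = ep.graph_module

# ExportPass instances are callable: call() rewires uses of each redundant node
# to its input, recompiles, runs dead-code elimination, and returns a PassResult.
result = RemoveRedundantOpsTransform()(gm)
print(result.graph_module.graph)  # the clone node is gone; the add reads x directly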

backends/vulkan/op_registry.py

Lines changed: 2 additions & 0 deletions

@@ -228,6 +228,8 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        # dim order copy operator will be removed; memory layout is handled internally
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 2 deletions

@@ -17,11 +17,11 @@
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
+    RemoveRedundantOpsTransform,
     TagMemoryMetaPass,
 )
 

@@ -143,7 +143,7 @@ def preprocess( # noqa: C901
         program = apply_passes(
             program,
             [
-                RemoveCloneOpsTransform(),
+                RemoveRedundantOpsTransform(),
                 AddmmToLinearTransform(),
                 FuseDequantLinearPass(),
                 FuseViewCopyTransform(),
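
A behavioral note on this swap (illustration, not part of the commit): RemoveRedundantOpsTransform covers everything RemoveCloneOpsTransform removed, and additionally drops alias, lift_fresh_copy, _to_dim_order_copy, and dtype-preserving _to_copy nodes. The standalone helper below is a hypothetical restatement of the _to_copy guard from _should_remove, shown in isolation for clarity.

import torch

from executorch.exir.dialects._ops import ops as exir_ops


def is_noop_to_copy(node: torch.fx.Node) -> bool:  # hypothetical helper, not in the diff
    """Mirrors the _to_copy check in RemoveRedundantOpsTransform._should_remove."""
    if node.target not in (
        exir_ops.edge.aten._to_copy.default,
        torch.ops.aten._to_copy.default,
    ):
        return False
    # A _to_copy that keeps the dtype only changes memory format or layout,
    # which the Vulkan backend handles internally, so the node can be dropped.
    return node.meta["val"].dtype == node.args[0].meta["val"].dtype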
