Skip to content

Commit 07970da

Browse files
nathanaelseefacebook-github-bot
authored and committed
remove clone ops pass (#4058)
Summary: Pull Request resolved: #4058. In the Vulkan backend, we always create copies when running an op, so the clone op is effectively a no-op with extra memory reads/writes. This change adds a pass to strip out clone ops. Reviewed By: copyrightly Differential Revision: D58761417 fbshipit-source-id: 1d9a1cdd295626cda8e4308b45ae6a6c2502b079
1 parent fdeda8e commit 07970da

File tree

4 files changed

+47
-0
lines changed

4 files changed

+47
-0
lines changed

backends/transforms/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ runtime.python_library(
7373
],
7474
)
7575

76+
# Python library target for the clone-op removal pass (remove_clone_ops.py).
# Visible to every target under //executorch/backends/; depends on torch and
# the exir pass_base and dialects libraries.
runtime.python_library(
77+
name = "remove_clone_ops",
78+
srcs = ["remove_clone_ops.py"],
79+
visibility = [
80+
"//executorch/backends/...",
81+
],
82+
deps = [
83+
"//caffe2:torch",
84+
"//executorch/exir:pass_base",
85+
"//executorch/exir/dialects:lib",
86+
],
87+
)
88+
7689
runtime.python_library(
7790
name = "mean_to_sum_div",
7891
srcs = ["mean_to_sum_div.py"],
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Remove clone op nodes and replace uses with parent node.

    Every ``call_function`` node whose target is the edge-dialect
    ``aten.clone.default`` op has its users rewired to the clone's first
    positional argument. The clone nodes, now without users, are then
    dropped by dead code elimination.

    Args:
        graph: The FX graph to rewrite. It is mutated in place.

    Returns:
        The same ``graph`` object that was passed in.
    """
    clone_op = exir_ops.edge.aten.clone.default
    for node in graph.nodes:
        if node.op != "call_function" or node.target != clone_op:
            continue
        # Rewiring users creates no new nodes, so no insertion point
        # (``graph.inserting_after``) needs to be set up here.
        node.replace_all_uses_with(node.args[0])

    # The clone nodes no longer have users; dead code elimination erases them.
    graph.eliminate_dead_code()
    return graph
26+
27+
28+
class RemoveCloneOpsTransform(ExportPass):
    """Export pass that strips clone ops out of a graph module."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run ``remove_clone_ops`` on the module's graph and store the result.

        Args:
            graph_module: The module whose graph is rewritten.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; the second field is
            always ``True``, regardless of whether any clone op was found.
        """
        cleaned_graph = remove_clone_ops(graph_module.graph)
        graph_module.graph = cleaned_graph
        return PassResult(graph_module, True)

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ runtime.python_library(
2626
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
2727
"//executorch/backends/transforms:fuse_conv_with_clamp",
2828
"//executorch/backends/transforms:fuse_view_copy",
29+
"//executorch/backends/transforms:remove_clone_ops",
2930
"//executorch/exir:graph_module",
3031
"//executorch/exir/_serialize:_bindings",
3132
"//executorch/exir/_serialize:lib",

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
1616
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
17+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1718

1819
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
1920
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -47,6 +48,7 @@ def preprocess( # noqa: C901
4748
module_compile_spec: List[CompileSpec],
4849
) -> PreprocessResult:
4950
passes = [
51+
RemoveCloneOpsTransform(),
5052
AddmmToLinearTransform(),
5153
FuseViewCopyTransform(),
5254
FuseBatchNormWithConvPass(program),

0 commit comments

Comments
 (0)