Skip to content

Commit 05545eb

Browse files
metascroyfacebook-github-bot
authored andcommitted
Remove redundant view_copy ops (pytorch#2278)
Summary: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit Reviewed By: JacobSzwejbka, larryliu0820 Differential Revision: D54498251
1 parent 0b6add8 commit 05545eb

File tree

4 files changed

+148
-0
lines changed

4 files changed

+148
-0
lines changed

exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ python_library(
1717
":prim_ops_py_registry",
1818
":quant_fusion_pass",
1919
":remove_noop_pass",
20+
":remove_redundant_view_copy_pass",
2021
":replace_aten_with_edge_pass",
2122
":replace_broken_ops_with_function_ops_pass",
2223
":replace_edge_with_backend_pass",
@@ -299,3 +300,14 @@ python_library(
299300
"//executorch/exir/dialects/edge:lib",
300301
],
301302
)
303+
304+
python_library(
305+
name = "remove_redundant_view_copy_pass",
306+
srcs = [
307+
"remove_redundant_view_copy_pass.py",
308+
],
309+
deps = [
310+
"//caffe2:torch",
311+
"//executorch/exir/dialects:lib",
312+
],
313+
)

exir/passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
4545
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4646
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass
47+
from executorch.exir.passes.remove_redundant_view_copy_pass import (
48+
RemoveRedundantViewCopyPass,
49+
)
4750
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
4851
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
4952
ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -481,6 +484,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
481484
ScalarToTensorPass(),
482485
SymToTensorPass(),
483486
RemoveNoopPass(),
487+
RemoveRedundantViewCopyPass(),
484488
]
485489
).passes
486490

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops
13+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
14+
15+
16+
def _is_view_copy(node: torch.fx.Node) -> bool:
17+
return node.op == "call_function" and node.target in (
18+
torch.ops.aten.view_copy.default,
19+
ops.edge.aten.view_copy.default,
20+
)
21+
22+
23+
def _maybe_remove_view_copy(node: torch.fx.Node) -> bool:
24+
assert _is_view_copy(node)
25+
26+
# Remove node if all users are views
27+
for user in node.users:
28+
if not _is_view_copy(user):
29+
return False
30+
31+
base = node.args[0]
32+
node.replace_all_uses_with(base)
33+
node.graph.erase_node(node)
34+
return True
35+
36+
37+
class RemoveRedundantViewCopyPass(PassBase):
38+
"""
39+
Removes redundant view_copy nodes.
40+
41+
A view_copy is redundant if all of its users are view_copy. Consider the
42+
following example:
43+
op1 -> view_copy1 -> view_copy2 -> view_copy3 -> op2.
44+
45+
Provided view_copy1 and view_copy2 have no users outside the illustration
46+
above, we can remove them and shorten the graph to
47+
op1 -> view_copy3 -> op2.
48+
49+
"""
50+
51+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
52+
n_removed = 0 # number of redundant view_copy nodes removed
53+
for module in graph_module.modules():
54+
if not isinstance(module, torch.fx.GraphModule):
55+
continue
56+
57+
for node in module.graph.nodes:
58+
if _is_view_copy(node):
59+
removed = _maybe_remove_view_copy(node)
60+
if removed:
61+
n_removed += 1
62+
module.recompile()
63+
64+
logging.info(f"Removed {n_removed} view_copy nodes.")
65+
any_removed = n_removed > 0
66+
return PassResult(graph_module, any_removed)

exir/tests/test_passes.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,3 +1244,69 @@ def forward(self, x):
12441244
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
12451245
# return (copy__default, aten_add_tensor)
12461246
self.assertEqual(count_copies(gm), 1)
1247+
1248+
def test_remove_redundant_view_copy_pass(self) -> None:
1249+
def is_view(node: torch.fx.Node) -> bool:
1250+
return node.op == "call_function" and node.target in (
1251+
torch.ops.aten.view_copy.default,
1252+
exir_ops.edge.aten.view_copy.default,
1253+
# before to_edge, the view_copy are view
1254+
# we include these to count n_views_before
1255+
torch.ops.aten.view.default,
1256+
exir_ops.edge.aten.view.default,
1257+
)
1258+
1259+
# Test chain
1260+
class ViewChain(torch.nn.Module):
1261+
def forward(self, x):
1262+
return x.reshape(30, 1).reshape(5, 6).reshape(2, 15).reshape(3, -1)
1263+
1264+
view_chain = ViewChain()
1265+
1266+
exported_program = export(view_chain, (torch.ones(30),))
1267+
n_views_before = 0
1268+
for node in exported_program.graph.nodes:
1269+
if is_view(node):
1270+
n_views_before += 1
1271+
self.assertEqual(n_views_before, 4)
1272+
1273+
edge_program_manager = to_edge(exported_program)
1274+
n_views_after = 0
1275+
for node in edge_program_manager.exported_program().graph.nodes:
1276+
if is_view(node):
1277+
n_views_after += 1
1278+
the_view_copy_node = node
1279+
1280+
self.assertEqual(n_views_after, 1)
1281+
self.assertEqual(the_view_copy_node.args[1], [3, -1])
1282+
self.assertEqual(
1283+
the_view_copy_node.target, exir_ops.edge.aten.view_copy.default
1284+
)
1285+
1286+
# Test branch
1287+
class ViewBranch(torch.nn.Module):
1288+
def forward(self, x):
1289+
x = x.reshape(30, 1)
1290+
y = torch.sum(x)
1291+
z = x.reshape(5, 6).reshape(2, 15).reshape(3, -1)
1292+
return z + y
1293+
1294+
view_branch = ViewBranch()
1295+
exported_program = export(view_branch, (torch.ones(30),))
1296+
n_views_before = 0
1297+
for node in exported_program.graph.nodes:
1298+
if is_view(node):
1299+
n_views_before += 1
1300+
self.assertEqual(n_views_before, 4)
1301+
1302+
edge_program_manager = to_edge(exported_program)
1303+
n_views_after = 0
1304+
for node in edge_program_manager.exported_program().graph.nodes:
1305+
if is_view(node):
1306+
n_views_after += 1
1307+
the_view_copy_node = node
1308+
1309+
# We keep the view on x (which is consumed by y)
1310+
# The 3 views on z are collapsed to 1 view
1311+
# In total, 2 view remain
1312+
self.assertEqual(n_views_after, 2)

0 commit comments

Comments
 (0)