Skip to content

Remove redundant view_copy ops #2278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ python_library(
":prim_ops_py_registry",
":quant_fusion_pass",
":remove_noop_pass",
":remove_redundant_view_copy_pass",
":replace_aten_with_edge_pass",
":replace_broken_ops_with_function_ops_pass",
":replace_edge_with_backend_pass",
Expand Down Expand Up @@ -299,3 +300,14 @@ python_library(
"//executorch/exir/dialects/edge:lib",
],
)

python_library(
name = "remove_redundant_view_copy_pass",
srcs = [
"remove_redundant_view_copy_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)
4 changes: 4 additions & 0 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass
from executorch.exir.passes.remove_redundant_view_copy_pass import (
RemoveRedundantViewCopyPass,
)
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
ReplaceBrokenOpsWithFunctionalOpsPass,
Expand Down Expand Up @@ -481,6 +484,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
ScalarToTensorPass(),
SymToTensorPass(),
RemoveNoopPass(),
RemoveRedundantViewCopyPass(),
]
).passes

Expand Down
66 changes: 66 additions & 0 deletions exir/passes/remove_redundant_view_copy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

import torch
from executorch.exir.dialects._ops import ops
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def _is_view_copy(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.view_copy.default,
ops.edge.aten.view_copy.default,
)


def _maybe_remove_view_copy(node: torch.fx.Node) -> bool:
assert _is_view_copy(node)

# Remove node if all users are views
for user in node.users:
if not _is_view_copy(user):
return False

base = node.args[0]
node.replace_all_uses_with(base)
node.graph.erase_node(node)
return True


class RemoveRedundantViewCopyPass(PassBase):
"""
Removes redundant view_copy nodes.
A view_copy is redundant if all of its users are view_copy. Consider the
following example:
op1 -> view_copy1 -> view_copy2 -> view_copy3 -> op2.
Provided view_copy1 and view_copy2 have no users outside the illustration
above, we can remove them and shorten the graph to
op1 -> view_copy3 -> op2.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
n_removed = 0 # number of redundant view_copy nodes removed
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue

for node in module.graph.nodes:
if _is_view_copy(node):
removed = _maybe_remove_view_copy(node)
if removed:
n_removed += 1
module.recompile()

logging.info(f"Removed {n_removed} view_copy nodes.")
any_removed = n_removed > 0
return PassResult(graph_module, any_removed)
1 change: 0 additions & 1 deletion exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ python_unittest(
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:pass_base",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:backend_details",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
Expand Down
66 changes: 66 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,3 +1244,69 @@ def forward(self, x):
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
# return (copy__default, aten_add_tensor)
self.assertEqual(count_copies(gm), 1)

def test_remove_redundant_view_copy_pass(self) -> None:
def is_view(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.view_copy.default,
exir_ops.edge.aten.view_copy.default,
# before to_edge, the view_copy are view
# we include these to count n_views_before
torch.ops.aten.view.default,
exir_ops.edge.aten.view.default,
)

# Test chain
class ViewChain(torch.nn.Module):
def forward(self, x):
return x.reshape(30, 1).reshape(5, 6).reshape(2, 15).reshape(3, -1)

view_chain = ViewChain()

exported_program = export(view_chain, (torch.ones(30),))
n_views_before = 0
for node in exported_program.graph.nodes:
if is_view(node):
n_views_before += 1
self.assertEqual(n_views_before, 4)

edge_program_manager = to_edge(exported_program)
n_views_after = 0
for node in edge_program_manager.exported_program().graph.nodes:
if is_view(node):
n_views_after += 1
the_view_copy_node = node

self.assertEqual(n_views_after, 1)
self.assertEqual(the_view_copy_node.args[1], [3, -1])
self.assertEqual(
the_view_copy_node.target, exir_ops.edge.aten.view_copy.default
)

# Test branch
class ViewBranch(torch.nn.Module):
def forward(self, x):
x = x.reshape(30, 1)
y = torch.sum(x)
z = x.reshape(5, 6).reshape(2, 15).reshape(3, -1)
return z + y

view_branch = ViewBranch()
exported_program = export(view_branch, (torch.ones(30),))
n_views_before = 0
for node in exported_program.graph.nodes:
if is_view(node):
n_views_before += 1
self.assertEqual(n_views_before, 4)

edge_program_manager = to_edge(exported_program)
n_views_after = 0
for node in edge_program_manager.exported_program().graph.nodes:
if is_view(node):
n_views_after += 1
the_view_copy_node = node

# We keep the view on x (which is consumed by y)
# The 3 views on z are collapsed to 1 view
# In total, 2 view remain
self.assertEqual(n_views_after, 2)