Skip to content

Replace view_copy with view (1/3) #2461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,14 @@ python_library(
"//executorch/exir/dialects/edge:lib",
],
)

python_library(
name = "normalize_view_copy_base_pass",
srcs = [
"normalize_view_copy_base_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)
64 changes: 64 additions & 0 deletions exir/passes/normalize_view_copy_base_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

import torch

from executorch.exir.dialects._ops import ops
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def _is_view_copy(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.view_copy.default,
ops.edge.aten.view_copy.default,
)


class NormalizeViewCopyBasePass(PassBase):
"""
Point each view_copy to the first upstream non-view.
After this pass, the base of each view_copy is not a view_copy.
When combined with dead-code elimination, this pass removes redundant
view_copy nodes.
TODO: replace RemoveRedundantViewCopyPass with NormalizeViewCopyBasePass + dead code elimination.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
n_updated = 0
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, size = node.args
if _is_view_copy(base):
# Point base to bases's base and update node's args
# Base's base will not be a view_copy because we iterate
# through the graph in topological order, replacing as we go.
base = base.args[0]
node.args = (base, size)
n_updated += 1

module.recompile()

logging.debug(f"Updated the base on {n_updated} view_copy nodes.")
return PassResult(graph_module, n_updated > 0)

def ensures(self, graph_module: torch.fx.GraphModule) -> None:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, size = node.args
assert not _is_view_copy(base)
1 change: 1 addition & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ python_unittest(
"//executorch/exir/passes:debug_handle_generator_pass",
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_edge_with_backend_pass",
Expand Down
53 changes: 53 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
Expand Down Expand Up @@ -1420,3 +1423,53 @@ def forward(self, x):
for node in edge.exported_program().graph_module.graph.nodes
)
)

def test_normalize_view_copy_base_pass(self) -> None:

class ViewChain(torch.nn.Module):
def forward(self, x):
x = torch.ops.aten.view_copy.default(x, [30, 1])
x = torch.ops.aten.view_copy.default(x, [5, 6])
x = torch.ops.aten.view_copy.default(x, [2, 15])
x = torch.ops.aten.view_copy.default(x, [3, -1])
return x

def is_view_copy(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops.aten.view_copy.default
)

gm = export(ViewChain(), (torch.ones(30),)).graph_module

# Check before transformation
n_view_copy_before = 0
n_view_copy_bases_before = 0
for node in gm.graph.nodes:
if is_view_copy(node):
n_view_copy_before += 1
base = node.args[0]
if is_view_copy(base):
n_view_copy_bases_before += 1

self.assertEqual(n_view_copy_before, 4)
self.assertEqual(n_view_copy_bases_before, 3)

# Do transformation
p = NormalizeViewCopyBasePass()
gm_res = p(gm)
assert gm_res is not None
gm = gm_res.graph_module

# Check after transformation
n_view_copy_after = 0
n_view_copy_bases_after = 0
for node in gm.graph.nodes:
if is_view_copy(node):
n_view_copy_after += 1
base = node.args[0]
if is_view_copy(base):
n_view_copy_bases_after += 1

self.assertEqual(n_view_copy_after, 4)
self.assertEqual(n_view_copy_bases_after, 0)