Skip to content

Commit c110654

Browse files
metascroyfacebook-github-bot
authored andcommitted
Remove RemoveRedundantViewCopyPass (#2464)
Summary: The RemoveRedundantViewCopyPass is unnecessary and can be replaced by NormalizeViewCopyBasePass + dead code elimintation. Differential Revision: D54866523
1 parent 137a387 commit c110654

File tree

5 files changed

+13
-85
lines changed

5 files changed

+13
-85
lines changed

exir/passes/TARGETS

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ python_library(
1414
":memory_format_ops_pass",
1515
":memory_planning_pass",
1616
":normalize_transpose_pass",
17+
":normalize_view_copy_base_pass",
1718
":prim_ops_py_registry",
1819
":quant_fusion_pass",
1920
":remove_noop_pass",
20-
":remove_redundant_view_copy_pass",
2121
":replace_aten_with_edge_pass",
2222
":replace_broken_ops_with_function_ops_pass",
2323
":replace_edge_with_backend_pass",
@@ -302,17 +302,6 @@ python_library(
302302
],
303303
)
304304

305-
python_library(
306-
name = "remove_redundant_view_copy_pass",
307-
srcs = [
308-
"remove_redundant_view_copy_pass.py",
309-
],
310-
deps = [
311-
"//caffe2:torch",
312-
"//executorch/exir/dialects:lib",
313-
],
314-
)
315-
316305
python_library(
317306
name = "normalize_view_copy_base_pass",
318307
srcs = [

exir/passes/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
4343
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
45+
from executorch.exir.passes.normalize_view_copy_base_pass import (
46+
NormalizeViewCopyBasePass,
47+
)
4548
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
4649
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass
47-
from executorch.exir.passes.remove_redundant_view_copy_pass import (
48-
RemoveRedundantViewCopyPass,
49-
)
5050
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
5151
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
5252
ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -485,7 +485,9 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
485485
ScalarToTensorPass(),
486486
SymToTensorPass(),
487487
RemoveNoopPass(),
488-
RemoveRedundantViewCopyPass(),
488+
# Running NormalizeViewCopyBasePass + dead code elimination
489+
# removes redundant view_copy nodes
490+
NormalizeViewCopyBasePass(),
489491
]
490492
).passes
491493

exir/passes/normalize_view_copy_base_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ class NormalizeViewCopyBasePass(PassBase):
2929
3030
When combined with dead-code elimination, this pass removes redundant
3131
view_copy nodes.
32-
33-
TODO: replace RemoveRedundantViewCopyPass with NormalizeViewCopyBasePass + dead code elimination.
3432
"""
3533

3634
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

exir/passes/remove_redundant_view_copy_pass.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

exir/tests/test_passes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,11 @@ def forward(self, x):
12541254
self.assertEqual(count_copies(gm), 1)
12551255

12561256
def test_remove_redundant_view_copy_pass(self) -> None:
1257+
# This tests that redundant view_copy nodes are removed
1258+
# during to_edge. There is no pass that explicitly does this.
1259+
# It results from running the NormalizeViewCopyBasePass and dead code
1260+
# elimination.
1261+
12571262
def is_view(node: torch.fx.Node) -> bool:
12581263
return node.op == "call_function" and node.target in (
12591264
torch.ops.aten.view_copy.default,
@@ -1271,7 +1276,7 @@ def forward(self, x):
12711276

12721277
view_chain = ViewChain()
12731278

1274-
exported_program = export(view_chain, (torch.ones(30),))
1279+
exported_program = torch.export.export(view_chain, (torch.ones(30),))
12751280
n_views_before = 0
12761281
for node in exported_program.graph.nodes:
12771282
if is_view(node):

0 commit comments

Comments
 (0)