Skip to content

Commit 4f358a6

Browse files
Scott Royfacebook-github-bot
authored andcommitted
Remove static view_copy (1/4): normalize_view_copy_base_pass
Summary: Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib When remove_static_view_copy is turned off (state today), the pass flow in to_executorch is: 1. config.to_out_var_pass 2. config.memory_planning_pass When remove_static_view_copy is turned on, the pass flow in to_executorch becomes: 1. NormalizeViewCopyBasePass() 2. ReplaceStaticViewCopyWithMemoryViewPass() (introduces executorch.exir.memory.view) 3. config.to_out_var_pass (skips executorch.exir.memory.view) 4. config.memory_planning_pass 5. ReplaceMemoryViewWithAllocPass() (removes executorch.exir.memory.view) The basic idea is to replace view_copy with a new operator executorch.exir.memory.view before memory planning (ReplaceStaticViewCopyWithMemoryViewPass). These nodes share the same spec as their base so that lifetimes are updated appropriately during memory planning. After memory planning, these nodes are converted to executorch.exir.memory.alloc nodes before emission. They are not converted to alloc nodes before memory planning. Before memory planning, memory.view nodes share the same spec as their base, but after memory planning, they get new specs when they're converted to memory.alloc (pointing to the same storage as base). Differential Revision: https://internalfb.com/D54816555
1 parent d4507c9 commit 4f358a6

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

exir/passes/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,14 @@ python_library(
312312
"//executorch/exir/dialects:lib",
313313
],
314314
)
315+
316+
python_library(
317+
name = "normalize_view_copy_base_pass",
318+
srcs = [
319+
"normalize_view_copy_base_pass.py",
320+
],
321+
deps = [
322+
"//caffe2:torch",
323+
"//executorch/exir/dialects:lib",
324+
],
325+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
import torch
12+
13+
from executorch.exir.dialects._ops import ops
14+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
15+
16+
17+
def _is_view_copy(node: torch.fx.Node) -> bool:
18+
return node.op == "call_function" and node.target in (
19+
torch.ops.aten.view_copy.default,
20+
ops.edge.aten.view_copy.default,
21+
)
22+
23+
24+
class NormalizeViewCopyBasePass(PassBase):
25+
"""
26+
Point each view_copy to the first upstream non-view.
27+
28+
After this pass, the base of each view_copy is not a view_copy.
29+
30+
When combined with dead-code elimination, this pass removes redundant
31+
view_copy nodes.
32+
33+
TODO: replace RemoveRedundantViewCopyPass with NormalizeViewCopyBasePass + dead code elimination.
34+
"""
35+
36+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
37+
n_updated = 0
38+
for module in graph_module.modules():
39+
if not isinstance(module, torch.fx.GraphModule):
40+
continue
41+
for node in module.graph.nodes:
42+
if _is_view_copy(node):
43+
base, size = node.args
44+
if _is_view_copy(base):
45+
# Point base to bases's base and update node's args
46+
# Base's base will not be a view_copy because we iterate
47+
# through the graph in topological order, replacing as we go.
48+
base = base.args[0]
49+
node.args = (base, size)
50+
n_updated += 1
51+
52+
module.recompile()
53+
54+
logging.debug(f"Updated the base on {n_updated} view_copy nodes.")
55+
return PassResult(graph_module, n_updated > 0)
56+
57+
def ensures(self, graph_module: torch.fx.GraphModule) -> None:
58+
for module in graph_module.modules():
59+
if not isinstance(module, torch.fx.GraphModule):
60+
continue
61+
for node in module.graph.nodes:
62+
if _is_view_copy(node):
63+
base, size = node.args
64+
assert not _is_view_copy(base)

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ python_unittest(
216216
"//executorch/exir/passes:debug_handle_generator_pass",
217217
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
218218
"//executorch/exir/passes:lib",
219+
"//executorch/exir/passes:normalize_view_copy_base_pass",
219220
"//executorch/exir/passes:remove_graph_asserts_pass",
220221
"//executorch/exir/passes:remove_mixed_type_operators",
221222
"//executorch/exir/passes:replace_edge_with_backend_pass",

exir/tests/test_passes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3737
insert_write_back_for_buffers_pass,
3838
)
39+
from executorch.exir.passes.normalize_view_copy_base_pass import (
40+
NormalizeViewCopyBasePass,
41+
)
3942
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
4043
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
4144
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
@@ -1310,3 +1313,53 @@ def forward(self, x):
13101313
# The 3 views on z are collapsed to 1 view
13111314
# In total, 2 view remain
13121315
self.assertEqual(n_views_after, 2)
1316+
1317+
def test_normalize_view_copy_base_pass(self) -> None:
1318+
1319+
class ViewChain(torch.nn.Module):
1320+
def forward(self, x):
1321+
x = torch.ops.aten.view_copy.default(x, [30, 1])
1322+
x = torch.ops.aten.view_copy.default(x, [5, 6])
1323+
x = torch.ops.aten.view_copy.default(x, [2, 15])
1324+
x = torch.ops.aten.view_copy.default(x, [3, -1])
1325+
return x
1326+
1327+
def is_view_copy(node: torch.fx.Node) -> bool:
1328+
return (
1329+
node.op == "call_function"
1330+
and node.target == torch.ops.aten.view_copy.default
1331+
)
1332+
1333+
gm = export(ViewChain(), (torch.ones(30),)).graph_module
1334+
1335+
# Check before transformation
1336+
n_view_copy_before = 0
1337+
n_view_copy_bases_before = 0
1338+
for node in gm.graph.nodes:
1339+
if is_view_copy(node):
1340+
n_view_copy_before += 1
1341+
base = node.args[0]
1342+
if is_view_copy(base):
1343+
n_view_copy_bases_before += 1
1344+
1345+
self.assertEqual(n_view_copy_before, 4)
1346+
self.assertEqual(n_view_copy_bases_before, 3)
1347+
1348+
# Do transformation
1349+
p = NormalizeViewCopyBasePass()
1350+
gm_res = p(gm)
1351+
assert gm_res is not None
1352+
gm = gm_res.graph_module
1353+
1354+
# Check after transformation
1355+
n_view_copy_after = 0
1356+
n_view_copy_bases_after = 0
1357+
for node in gm.graph.nodes:
1358+
if is_view_copy(node):
1359+
n_view_copy_after += 1
1360+
base = node.args[0]
1361+
if is_view_copy(base):
1362+
n_view_copy_bases_after += 1
1363+
1364+
self.assertEqual(n_view_copy_after, 4)
1365+
self.assertEqual(n_view_copy_bases_after, 0)

0 commit comments

Comments
 (0)