Skip to content

Commit 3639e5b

Browse files
metascroyfacebook-github-bot
authored andcommitted
Replace view_copy with view (1/3) (#2461)
Summary: Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib This stack replaces view_copy nodes with memory.view nodes. In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means if we have something like op -> view_copy1 -> view_copy2, then after normalization, both view copies will point to op in their base (assuming op is not a view node). Note that this pass combined with dead-code elimination removes redundant view copies. This is because a redundant view copy will have no users have this pass. In the second diff (D54827305), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, this update is reflected in the memory.view node's _MemoryViewSpec. Not all view_copy nodes are converted to memory.view nodes. Only static nodes that are memory planned are converted. Not all static nodes are memory planned in ExecuTorch. For example, there is an option to turn off memory planning for input nodes, and outputs from some higher order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time. In the third diff (D54827438), I implement the actual view_copy elimination. In the ExecutorchBackendConfig, there is a new option remove_static_view_copy. If remove_static_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if remove_static_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (state today). Let's look at the flow when remove_static_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are the just the first and second diff described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update the lifetime), we return the spec of its base. Returning the spec for the base means that whenever we see a memory.view node, we actually update the lifetime of the base to cover it. Moreover, the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec for the memory.view node. This is because the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue. There are two more diffs on the stack D54866523 and D54866539. The first of these replaces the old RemoveRedundantViewCopy pass with a NormalizeViewCopyBasePass + dead code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when safe to do so to take advantage of the view_copy elimination. Reviewed By: larryliu0820 Differential Revision: D54816555
1 parent a89df58 commit 3639e5b

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

exir/passes/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,14 @@ python_library(
312312
"//executorch/exir/dialects:lib",
313313
],
314314
)
315+
316+
python_library(
317+
name = "normalize_view_copy_base_pass",
318+
srcs = [
319+
"normalize_view_copy_base_pass.py",
320+
],
321+
deps = [
322+
"//caffe2:torch",
323+
"//executorch/exir/dialects:lib",
324+
],
325+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
import torch
12+
13+
from executorch.exir.dialects._ops import ops
14+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
15+
16+
17+
def _is_view_copy(node: torch.fx.Node) -> bool:
18+
return node.op == "call_function" and node.target in (
19+
torch.ops.aten.view_copy.default,
20+
ops.edge.aten.view_copy.default,
21+
)
22+
23+
24+
class NormalizeViewCopyBasePass(PassBase):
25+
"""
26+
Point each view_copy to the first upstream non-view.
27+
28+
After this pass, the base of each view_copy is not a view_copy.
29+
30+
When combined with dead-code elimination, this pass removes redundant
31+
view_copy nodes.
32+
33+
TODO: replace RemoveRedundantViewCopyPass with NormalizeViewCopyBasePass + dead code elimination.
34+
"""
35+
36+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
37+
n_updated = 0
38+
for module in graph_module.modules():
39+
if not isinstance(module, torch.fx.GraphModule):
40+
continue
41+
for node in module.graph.nodes:
42+
if _is_view_copy(node):
43+
base, size = node.args
44+
if _is_view_copy(base):
45+
# Point base to bases's base and update node's args
46+
# Base's base will not be a view_copy because we iterate
47+
# through the graph in topological order, replacing as we go.
48+
base = base.args[0]
49+
node.args = (base, size)
50+
n_updated += 1
51+
52+
module.recompile()
53+
54+
logging.debug(f"Updated the base on {n_updated} view_copy nodes.")
55+
return PassResult(graph_module, n_updated > 0)
56+
57+
def ensures(self, graph_module: torch.fx.GraphModule) -> None:
58+
for module in graph_module.modules():
59+
if not isinstance(module, torch.fx.GraphModule):
60+
continue
61+
for node in module.graph.nodes:
62+
if _is_view_copy(node):
63+
base, size = node.args
64+
assert not _is_view_copy(base)

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ python_unittest(
216216
"//executorch/exir/passes:debug_handle_generator_pass",
217217
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
218218
"//executorch/exir/passes:lib",
219+
"//executorch/exir/passes:normalize_view_copy_base_pass",
219220
"//executorch/exir/passes:remove_graph_asserts_pass",
220221
"//executorch/exir/passes:remove_mixed_type_operators",
221222
"//executorch/exir/passes:replace_edge_with_backend_pass",

exir/tests/test_passes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3737
insert_write_back_for_buffers_pass,
3838
)
39+
from executorch.exir.passes.normalize_view_copy_base_pass import (
40+
NormalizeViewCopyBasePass,
41+
)
3942
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
4043
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
4144
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
@@ -1310,3 +1313,53 @@ def forward(self, x):
13101313
# The 3 views on z are collapsed to 1 view
13111314
# In total, 2 view remain
13121315
self.assertEqual(n_views_after, 2)
1316+
1317+
def test_normalize_view_copy_base_pass(self) -> None:
1318+
1319+
class ViewChain(torch.nn.Module):
1320+
def forward(self, x):
1321+
x = torch.ops.aten.view_copy.default(x, [30, 1])
1322+
x = torch.ops.aten.view_copy.default(x, [5, 6])
1323+
x = torch.ops.aten.view_copy.default(x, [2, 15])
1324+
x = torch.ops.aten.view_copy.default(x, [3, -1])
1325+
return x
1326+
1327+
def is_view_copy(node: torch.fx.Node) -> bool:
1328+
return (
1329+
node.op == "call_function"
1330+
and node.target == torch.ops.aten.view_copy.default
1331+
)
1332+
1333+
gm = export(ViewChain(), (torch.ones(30),)).graph_module
1334+
1335+
# Check before transformation
1336+
n_view_copy_before = 0
1337+
n_view_copy_bases_before = 0
1338+
for node in gm.graph.nodes:
1339+
if is_view_copy(node):
1340+
n_view_copy_before += 1
1341+
base = node.args[0]
1342+
if is_view_copy(base):
1343+
n_view_copy_bases_before += 1
1344+
1345+
self.assertEqual(n_view_copy_before, 4)
1346+
self.assertEqual(n_view_copy_bases_before, 3)
1347+
1348+
# Do transformation
1349+
p = NormalizeViewCopyBasePass()
1350+
gm_res = p(gm)
1351+
assert gm_res is not None
1352+
gm = gm_res.graph_module
1353+
1354+
# Check after transformation
1355+
n_view_copy_after = 0
1356+
n_view_copy_bases_after = 0
1357+
for node in gm.graph.nodes:
1358+
if is_view_copy(node):
1359+
n_view_copy_after += 1
1360+
base = node.args[0]
1361+
if is_view_copy(base):
1362+
n_view_copy_bases_after += 1
1363+
1364+
self.assertEqual(n_view_copy_after, 4)
1365+
self.assertEqual(n_view_copy_bases_after, 0)

0 commit comments

Comments
 (0)