
Commit f7dd4fd

Scott Roy authored and facebook-github-bot committed
Replace view_copy with view (2/3)
Summary:

Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib

This stack replaces view_copy nodes with memory.view nodes.

In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means that if we have something like op -> view_copy1 -> view_copy2, then after normalization both view copies point to op in their base (assuming op is not a view node). Note that this pass, combined with dead-code elimination, removes redundant view copies, because a redundant view copy has no users after this pass.

In the second diff (D54827438), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, the update is reflected in the memory.view node's _MemoryViewSpec.

Not all view_copy nodes are converted to memory.view nodes; only static nodes that are memory planned are converted. Not all static nodes are memory planned in ExecuTorch: for example, there is an option to turn off memory planning for input nodes, and outputs from some higher-order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time.

In the third diff (D54827438), I implement the actual view_copy elimination. The ExecutorchBackendConfig gets a new option, remove_static_view_copy. If remove_static_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if remove_static_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (the state today).

Let's look at the flow when remove_static_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are just the first and second diffs described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update the lifetime), we return the spec of its base. Returning the spec of the base means that whenever we see a memory.view node, we actually update the lifetime of its base to cover it, and the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec for the memory.view node, because the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue.

There are two more diffs on the stack, D54866523 and D54866539. The first replaces the old RemoveRedundantViewCopy pass with NormalizeViewCopyBasePass + dead-code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when it is safe to do so, to take advantage of the view_copy elimination.
Differential Revision: https://internalfb.com/D54827305
1 parent e15b8f8 commit f7dd4fd
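
To make the pass ordering described in the summary concrete, here is a minimal sketch of the pipeline selection. The build_memory_planning_passes helper is hypothetical (illustration only, not the actual wiring), remove_static_view_copy is the option introduced later in the stack (hence the guarded getattr), and the pass carries the name ReplaceViewCopyWithViewPass in this commit's code, corresponding to the ReplaceViewCopyWithMemoryViewPass referenced in the summary:

    from typing import List

    from executorch.exir import ExecutorchBackendConfig
    from executorch.exir.passes.normalize_view_copy_base_pass import (
        NormalizeViewCopyBasePass,
    )
    from executorch.exir.passes.replace_view_copy_with_view_pass import (
        ReplaceViewCopyWithViewPass,
    )

    def build_memory_planning_passes(config: ExecutorchBackendConfig) -> List:
        # Hypothetical helper for illustration only.
        if getattr(config, "remove_static_view_copy", False):
            # Normalize view_copy bases so each one points at a non-view
            # upstream node, then convert eligible view_copy nodes to
            # memory.view before out-var conversion and memory planning.
            return [
                NormalizeViewCopyBasePass(),
                ReplaceViewCopyWithViewPass(),
                config.to_out_var_pass,
                config.memory_planning_pass,
            ]
        # State today: view_copy nodes are kept.
        return [config.to_out_var_pass, config.memory_planning_pass]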

File tree

5 files changed: +311 −0 lines changed


exir/memory.py

Lines changed: 9 additions & 0 deletions

@@ -39,3 +39,12 @@ def free(spec: TensorSpec) -> None:
     E.g., it can be the target of call_function node.
     """
     pass
+
+
+def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
+    """
+    This function mimics torch.ops.aten.view.default.
+
+    It is used to elide view_copy nodes.
+    """
+    return base.view(size)
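
As a quick illustration of what this helper does at the eager Python level (a sketch with arbitrary values, not part of the diff): memory.view returns a true view that shares storage with its base, exactly like Tensor.view.

    import torch

    from executorch.exir import memory

    base = torch.arange(6)
    v = memory.view(base, [2, 3])  # mimics torch.ops.aten.view.default

    assert v.data_ptr() == base.data_ptr()  # storage is shared; no copy is made
    assert v.shape == torch.Size([2, 3])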

exir/passes/TARGETS

Lines changed: 13 additions & 0 deletions

@@ -311,3 +311,16 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "replace_view_copy_with_view_pass",
+    srcs = [
+        "replace_view_copy_with_view_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:memory",
+        "//executorch/exir:tensor",
+        "//executorch/exir/dialects:lib",
+    ],
+)
exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 189 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import math
from typing import Any, Dict, List, Tuple

import torch
from executorch.exir import memory

from executorch.exir.dialects._ops import ops
from executorch.exir.tensor import (
    contiguous_stride_from_shape,
    determine_tensor_dynanism,
    dim_order_from_stride,
    TensorShapeDynamism,
    TensorSpec,
)
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)


def _is_view_copy(node: torch.fx.Node) -> bool:
    return node.op == "call_function" and node.target in (
        torch.ops.aten.view_copy.default,
        ops.edge.aten.view_copy.default,
    )


_VIEW_OP = memory.view


class _ViewSpec(TensorSpec):
    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
        """
        A _ViewSpec is an immutable TensorSpec that mirrors its base for
        non-size related information.
        """

        if math.prod(base.shape) != math.prod(shape):
            raise Exception(
                f"Cannot create a _ViewSpec because the provided shape {shape} is not consistent with the number of elements in the provided base ({math.prod(base.shape)})."
            )

        self._init_setters = [
            "_frozen",
            "_base",
            "_guards",
            "shape",
            "stride",
            "dim_order",
            "shape_dynamism",
        ]
        self._frozen = False
        self._base = base
        self.shape: List[int] = shape
        self.stride: Tuple[int] = contiguous_stride_from_shape(torch.Size(self.shape))
        self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(
            torch.Size(self.shape)
        )

        # This spec gives a view into its base.
        # The base can be modified (e.g., mem_id) and this spec will
        # update accordingly, but certain fields we do not expect to change.
        # We create guards for these.
        self._guards: Dict[str, Any] = {
            "shape_dynamism": base.shape_dynamism,
            "scalar_type": base.scalar_type,
            "layout": base.layout,
            "is_sparse": base.is_sparse,
        }
        self._frozen = True

    def _check_guards(self) -> None:
        for name in self._guards:
            if getattr(self._base, name) != self._guards[name]:
                raise Exception(
                    f"The guarded attribute '{name}' has changed value. At creation of the _ViewSpec, it was {self._guards[name]}, but it is now {getattr(self._base, name)}."
                )

    def __getattribute__(self, name):  # pyre-ignore
        if name in [
            "_init_setters",
            "_frozen",
            "_base",
            "_guards",
            "_check_guards",
            # Adding debug is needed so that view_spec.debug() shows the right id in
            # its string (if debug is excluded, it shows the id(view_spec._base) instead
            # of id(view_spec))
            "debug",
        ]:
            return object.__getattribute__(self, name)

        # Guard check after freeze
        if self._frozen:
            self._check_guards()

        # self._init_setters attributes come from self, others come from base
        if name in self._init_setters:
            return object.__getattribute__(self, name)
        return getattr(self._base, name)

    def __setattr__(self, name: str, val) -> None:  # pyre-ignore
        if name in ["_init_setters", "_frozen"]:
            object.__setattr__(self, name, val)
            return

        # Allow setting during initialization
        if name in self._init_setters and not self._frozen:
            object.__setattr__(self, name, val)
            return

        if name in self._init_setters:
            raise Exception(
                f"_ViewSpec is immutable. Cannot set the attribute '{name}' after creation."
            )

        raise Exception(
            f"_ViewSpec is immutable. To update the non-size related attribute '{name}', update the base."
        )


class ReplaceViewCopyWithViewPass(PassBase):
    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        This pass replaces view_copy nodes with view nodes.

        This should be run after the NormalizeViewCopyBasePass.

        During memory planning, view nodes share the same storage as their base.
        """

        n_replaced = 0
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if _is_view_copy(node):
                    base, _ = node.args
                    node.target = _VIEW_OP

                    # Create spec for the node.
                    # _ViewSpec is an immutable TensorSpec that gives a view into
                    # its base spec for non-size related information.

                    # The shape is not the same as node.args[1] because node.args[1]
                    # can have an inferred size (-1).
                    shape = node.meta["val"].shape
                    node.meta["spec"] = _ViewSpec(base.meta["spec"], shape)

                    n_replaced += 1

            module.recompile()

        logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
        return PassResult(graph_module, n_replaced > 0)

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                assert not _is_view_copy(node)
                if node.op == "call_function" and node.target == _VIEW_OP:
                    assert isinstance(node.meta["spec"], _ViewSpec)

    def requires(self, graph_module: torch.fx.GraphModule) -> None:
        """
        This pass should be called after NormalizeViewCopyBasePass.
        We check that all view_copy nodes have been normalized.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if _is_view_copy(node):
                    base, size = node.args
                    assert not _is_view_copy(base)
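
To show the mirroring behavior this spec implements, here is a small sketch. It pokes at the private _ViewSpec class directly, so treat it as illustrative rather than supported API:

    import torch

    from executorch.exir.passes.replace_view_copy_with_view_pass import _ViewSpec
    from executorch.exir.tensor import TensorSpec

    base_spec = TensorSpec.from_tensor(torch.zeros(2, 3))
    view_spec = _ViewSpec(base_spec, [3, 2])  # same numel, different shape

    base_spec.mem_id = 7           # update a non-size field on the base...
    assert view_spec.mem_id == 7   # ...and the view spec reflects it

    try:
        view_spec.mem_id = 0       # direct writes to the view spec are rejected
    except Exception as e:
        print(e)  # "_ViewSpec is immutable. To update ... update the base."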

exir/tests/TARGETS

Lines changed: 2 additions & 0 deletions

@@ -208,6 +208,7 @@ python_unittest(
         "//executorch/exir:memory",
         "//executorch/exir:memory_planning",
         "//executorch/exir:pass_base",
+        "//executorch/exir:schema",
         "//executorch/exir:tensor",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
@@ -220,6 +221,7 @@ python_unittest(
         "//executorch/exir/passes:remove_graph_asserts_pass",
         "//executorch/exir/passes:remove_mixed_type_operators",
         "//executorch/exir/passes:replace_edge_with_backend_pass",
+        "//executorch/exir/passes:replace_view_copy_with_view_pass",
         "//executorch/exir/passes:scalar_to_tensor_pass",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/passes:sym_to_tensor_pass",

exir/tests/test_passes.py

Lines changed: 98 additions & 0 deletions

@@ -39,13 +39,18 @@
 from executorch.exir.passes.normalize_view_copy_base_pass import (
     NormalizeViewCopyBasePass,
 )
+
 from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
 from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
+from executorch.exir.passes.replace_view_copy_with_view_pass import (
+    ReplaceViewCopyWithViewPass,
+)
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
 from executorch.exir.program._program import lift_constant_tensor_pass
+from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic

@@ -1473,3 +1478,96 @@ def is_view_copy(node: torch.fx.Node) -> bool:
 
         self.assertEqual(n_view_copy_after, 4)
         self.assertEqual(n_view_copy_bases_after, 0)
+
+    def test_replace_view_copy_with_view_pass(self) -> None:  # noqa: C901
+
+        # Helper functions
+        def is_view_copy(node: torch.fx.Node) -> bool:
+            return (
+                node.op == "call_function"
+                and node.target == torch.ops.aten.view_copy.default
+            )
+
+        def is_memory_view(node: torch.fx.Node) -> bool:
+            return node.op == "call_function" and node.target == memory.view
+
+        # Test example set up
+        class TestViewCopies(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.parameter = torch.nn.Parameter(torch.ones(1))
+
+            def forward(self, x):
+                o1 = torch.ops.aten.view_copy.default(
+                    self.parameter, [1]
+                )  # replaceable parameter
+                o2 = torch.ops.aten.view_copy.default(x, [1])  # replaceable user input
+                o3 = torch.ops.aten.view_copy.default(
+                    torch.ops.aten.relu.default(x), [1]
+                )  # replaceable dynamic unbound
+                o4 = torch.ops.aten.view_copy.default(
+                    torch.ops.aten.gelu.default(x), [1]
+                )  # replaceable dynamic bound
+                o5 = torch.ops.aten.view_copy.default(
+                    torch.ops.aten.tanh.default(x), [1]
+                )  # replaceable static
+                return o1, o2, o3, o4, o5
+
+        ep = torch.export.export(
+            TestViewCopies(),
+            args=(torch.ones(1),),
+        )
+        self.assertEqual(len(ep.graph.nodes), 11)
+        for node in ep.graph.nodes:
+            if node.op == "placeholder":
+                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
+                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC
+            elif node.target == torch.ops.aten.relu.default:
+                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
+                node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND
+            elif node.target == torch.ops.aten.gelu.default:
+                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
+                node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+            elif node.target == torch.ops.aten.tanh.default:
+                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
+                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC
+            elif node.target == torch.ops.aten.view_copy.default:
+                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
+                node.meta["spec"].shape_dynamism = (
+                    node.args[0].meta["spec"].shape_dynamism
+                )
+            else:
+                pass
+
+        # Run tests
+        gm = ep.graph_module
+
+        # Check before transformation
+        n_view_copy_before = 0
+        n_memory_view_before = 0
+        for node in gm.graph.nodes:
+            if is_view_copy(node):
+                n_view_copy_before += 1
+            if is_memory_view(node):
+                n_memory_view_before += 1
+
+        self.assertEqual(n_view_copy_before, 5)
+        self.assertEqual(n_memory_view_before, 0)
+
+        # Do transformation
+        p = ReplaceViewCopyWithViewPass()
+        gm_res = p(gm)
+        assert gm_res is not None
+        gm = gm_res.graph_module
+
+        # Check after transformation
+        n_view_copy_after = 0
+        n_memory_view_after = 0
+        for node in gm.graph.nodes:
+            if is_view_copy(node):
+                n_view_copy_after += 1
+            if is_memory_view(node):
+                n_memory_view_after += 1
+
+        self.assertEqual(n_view_copy_after, 0)
+        self.assertEqual(n_memory_view_after, 5)
