Commit 376523a

metascroy authored and facebook-github-bot committed
Replace view_copy with view (2/3) (#2462)
Summary:

Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib

This stack replaces view_copy nodes with memory.view nodes.

In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means that if we have something like op -> view_copy1 -> view_copy2, then after normalization both view copies point to op in their base (assuming op is not a view node). Note that this pass, combined with dead-code elimination, removes redundant view copies: a redundant view copy has no users after this pass.

In the second diff (D54827305), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size-related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, the update is reflected in the memory.view node's _MemoryViewSpec.

Not all view_copy nodes are converted to memory.view nodes. Only static nodes that are memory planned are converted, and not all static nodes are memory planned in ExecuTorch. For example, there is an option to turn off memory planning for input nodes, and outputs from some higher-order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time.

In the third diff (D54827438), I implement the actual view_copy elimination. In the ExecutorchBackendConfig, there is a new option remove_static_view_copy. If remove_static_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if remove_static_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (the state today).

Let's look at the flow when remove_static_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are just the first and second diffs described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update its lifetime), we return the spec of its base. Returning the spec of the base means that whenever we see a memory.view node, we actually update the lifetime of the base to cover it, and the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec for the memory.view node itself: the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue.

There are two more diffs on the stack, D54866523 and D54866539. The first replaces the old RemoveRedundantViewCopy pass with NormalizeViewCopyBasePass + dead-code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when it is safe to do so, to take advantage of the view_copy elimination.

Differential Revision: D54827305
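The normalization + dead-code-elimination interaction from the first diff can be sketched in pure Python. The Node class and helpers below are hypothetical stand-ins, not the ExecuTorch implementation; they only show why a redundant view copy loses all its users after normalization.

```python
# Toy graph: nodes track their base (the upstream node a view reads from)
# and their users. Normalization repoints every view at the first non-view
# ancestor; DCE then drops views with no users.

class Node:
    def __init__(self, name, is_view=False, base=None):
        self.name = name
        self.is_view = is_view
        self.base = base
        self.users = set()
        if base is not None:
            base.users.add(self)

def normalize(nodes):
    """Point every view node at the first non-view ancestor."""
    for n in nodes:
        if not n.is_view:
            continue
        root = n.base
        while root.is_view:
            root = root.base
        if root is not n.base:
            n.base.users.discard(n)
            n.base = root
            root.users.add(n)

def dce(nodes, outputs):
    """Drop view nodes that have no users and are not outputs."""
    return [n for n in nodes if n.users or n in outputs or not n.is_view]

# op -> view_copy1 -> view_copy2; only view_copy2 is used downstream.
op = Node("op")
v1 = Node("view_copy1", is_view=True, base=op)
v2 = Node("view_copy2", is_view=True, base=v1)
graph = [op, v1, v2]

normalize(graph)
assert v1.base is op and v2.base is op   # both views now point at op

graph = dce(graph, outputs=[v2])
assert [n.name for n in graph] == ["op", "view_copy2"]  # v1 was redundant
```

After normalization, view_copy1 no longer has any users (view_copy2 reads directly from op), so ordinary dead-code elimination removes it.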
1 parent 1eaab97 commit 376523a

File tree

6 files changed (+393, -10 lines)


exir/memory.py

Lines changed: 9 additions & 0 deletions
@@ -39,3 +39,12 @@ def free(spec: TensorSpec) -> None:
     E.g., it can be the target of call_function node.
     """
     pass
+
+
+def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
+    """
+    This function mimics torch.ops.aten.view.default.
+
+    It is used to elide memory-planned, static view_copy nodes.
+    """
+    return base.view(size)
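The property that memory.view relies on is that a view shares storage with its base, so no copy is ever materialized. As a stdlib analogy (not ExecuTorch code), a memoryview exposes the same bytes under new size/stride metadata:

```python
import array

base = array.array("d", [0.0] * 12)  # 12 doubles of storage
# memoryview.cast requires going through a byte format, so cast to "B"
# first, then reinterpret the same bytes as a 3x4 array of doubles.
view = memoryview(base).cast("B").cast("d", (3, 4))

base[0] = 7.0              # write through the base...
assert view[0, 0] == 7.0   # ...is visible through the view: shared storage
```

This is the behavior torch.ops.aten.view.default has and view_copy deliberately avoids; replacing a static view_copy with memory.view recovers the zero-copy form.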

exir/memory_planning.py

Lines changed: 29 additions & 10 deletions
@@ -35,6 +35,16 @@
 
 REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}
 
+SPECIAL_TARGETS = [  # pyre-ignore
+    memory.alloc,
+    memory.view,
+    operator.getitem,
+    torch.ops.higher_order.cond,
+    exir_while,
+    torch.ops.higher_order.map_impl,
+    executorch_call_delegate,
+]
+
 
 class Verifier:
     """
@@ -393,16 +403,7 @@ def collect_specs_from_nodes(  # noqa: C901
 
         if do_assertion:
             internal_assert(
-                node.op in ("placeholder", "output")
-                or node.target
-                in [
-                    memory.alloc,
-                    operator.getitem,
-                    torch.ops.higher_order.cond,
-                    exir_while,
-                    torch.ops.higher_order.map_impl,
-                    executorch_call_delegate,
-                ],
+                node.op in ("placeholder", "output") or node.target in SPECIAL_TARGETS,
                 f"Unexpected op {node.op}, target {node.target}",
             )
         for spec in specs:
@@ -689,6 +690,24 @@ def get_input_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
     return input_specs
 
 
+def is_op_memory_planned(node: torch.fx.Node) -> bool:
+    """
+    Return true if this call_function node is memory planned.
+    We are cautious in this function and are OK returning False even if a node
+    may be memory planned.
+    """
+    assert node.op == "call_function"
+    if node.target in SPECIAL_TARGETS:
+        # Assume nodes in SPECIAL_TARGETS are not memory planned,
+        # unless they are in the exceptions below
+        exceptions = [
+            executorch_call_delegate,
+        ]
+        return node.target in exceptions
+    return True
+
+
 def insert_calls_to_free(
     graph_module: fx.GraphModule, allspecs: Set[TensorSpec]
 ) -> None:
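The classification rule in is_op_memory_planned can be restated as a tiny sketch. The string sentinels below are hypothetical stand-ins for real targets like memory.alloc and torch.ops.higher_order.cond; the point is only the cautious default:

```python
# Anything outside SPECIAL_TARGETS is assumed memory planned; inside it,
# only the explicit exceptions are. False negatives are acceptable
# (they just mean a view_copy is not eliminated), false positives are not.
ALLOC, COND, CALL_DELEGATE, ADD = "alloc", "cond", "call_delegate", "add"
SPECIAL_TARGETS = [ALLOC, COND, CALL_DELEGATE]
EXCEPTIONS = [CALL_DELEGATE]

def is_op_memory_planned(target):
    if target in SPECIAL_TARGETS:
        return target in EXCEPTIONS
    return True

assert is_op_memory_planned(ADD)            # ordinary op: assumed planned
assert not is_op_memory_planned(COND)       # special target: cautious False
assert is_op_memory_planned(CALL_DELEGATE)  # special, but an exception
```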

exir/passes/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -323,3 +323,17 @@ python_library(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+python_library(
+    name = "replace_view_copy_with_memory_view_pass",
+    srcs = [
+        "replace_view_copy_with_memory_view_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:memory",
+        "//executorch/exir:memory_planning",
+        "//executorch/exir:tensor",
+        "//executorch/exir/dialects:lib",
+    ],
+)
exir/passes/replace_view_copy_with_memory_view_pass.py

Lines changed: 237 additions & 0 deletions

@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from executorch.exir import memory
+from executorch.exir.dialects._ops import ops
+from executorch.exir.memory_planning import is_op_memory_planned
+from executorch.exir.tensor import (
+    contiguous_stride_from_shape,
+    determine_tensor_dynanism,
+    dim_order_from_stride,
+    TensorShapeDynamism,
+    TensorSpec,
+)
+from torch.export.exported_program import ExportedProgram, ExportGraphSignature
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _is_view_copy(node: torch.fx.Node) -> bool:
+    return node.op == "call_function" and node.target in (
+        torch.ops.aten.view_copy.default,
+        ops.edge.aten.view_copy.default,
+    )
+
+
+class _MemoryViewSpec(TensorSpec):
+    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
+        """
+        A MemoryViewSpec is an immutable TensorSpec that mirrors its base for
+        non-size related information.
+        """
+
+        if math.prod(base.shape) != math.prod(shape):
+            raise Exception(
+                f"Cannot create a MemoryViewSpec because the provided shape {shape} is not consistent with the number of elements in the provided base ({math.prod(base.shape)})."
+            )
+
+        if not base.is_static_shape_tensor:
+            raise Exception(
+                "Cannot create a MemoryViewSpec because the provided base is not a static shape tensor."
+            )
+
+        self._init_setters = [
+            "_frozen",
+            "_base",
+            "_guards",
+            "shape",
+            "stride",
+            "dim_order",
+            "shape_dynamism",
+        ]
+        self._frozen = False
+        self._base = base
+        self.shape: List[int] = shape
+        self.stride: Tuple[int] = contiguous_stride_from_shape(torch.Size(self.shape))
+        self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
+        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(
+            torch.Size(self.shape)
+        )
+
+        # This spec gives a view into its base.
+        # The base can be modified (e.g., mem_id) and this spec will
+        # update accordingly, but certain fields we do not expect to change.
+        # We create guards for these.
+        self._guards: Dict[str, Any] = {
+            "shape_dynamism": base.shape_dynamism,
+            "scalar_type": base.scalar_type,
+            "layout": base.layout,
+            "is_sparse": base.is_sparse,
+        }
+        self._frozen = True
+
+    def _check_guards(self) -> None:
+        for name in self._guards:
+            if getattr(self._base, name) != self._guards[name]:
+                raise Exception(
+                    f"The guarded attribute '{name}' has changed value. At creation of the MemoryViewSpec, it was {self._guards[name]}, but it is now {getattr(self._base, name)}."
+                )
+
+    def __getattribute__(self, name):  # pyre-ignore
+        if name in [
+            "_init_setters",
+            "_frozen",
+            "_base",
+            "_guards",
+            "_check_guards",
+            # Adding debug is needed so that view_spec.debug() shows the right id in
+            # its string (if debug is excluded, it shows the id(view_spec._base) instead
+            # of id(view_spec))
+            "debug",
+        ]:
+            return object.__getattribute__(self, name)
+
+        # Guard check after freeze
+        if self._frozen:
+            self._check_guards()
+
+        # self._init_setters attributes come from self, others come from base
+        if name in self._init_setters:
+            return object.__getattribute__(self, name)
+        return getattr(self._base, name)
+
+    def __setattr__(self, name: str, val) -> None:  # pyre-ignore
+        if name in ["_init_setters", "_frozen"]:
+            object.__setattr__(self, name, val)
+            return
+
+        # Allow setting during initialization
+        if name in self._init_setters and not self._frozen:
+            object.__setattr__(self, name, val)
+            return
+
+        if name in self._init_setters:
+            raise Exception(
+                f"MemoryViewSpec is immutable. Cannot set the attribute '{name}' after creation."
+            )
+
+        raise Exception(
+            f"MemoryViewSpec is immutable. To update the non-size related attribute '{name}', update the base."
+        )
+
+
+class ReplaceViewCopyWithMemoryViewPass(PassBase):
+    def __init__(self) -> None:
+        super().__init__()
+        self._graph_signature: Optional[ExportGraphSignature] = None
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """
+        This pass replaces all static, memory-planned view_copy nodes with
+        special memory.view nodes.
+
+        This should be run after the NormalizeViewCopyBasePass.
+
+        During memory planning, memory.view nodes share the same storage as
+        their base.
+
+        During emission, memory.view nodes are not emitted as operators, but
+        are instead directly emitted as evalues. They share the same
+        allocation_info/storage as their base.
+        """
+
+        n_replaced = 0
+        for module in graph_module.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in module.graph.nodes:
+                if self._is_replaceable_view_copy(node):
+                    base, _ = node.args
+                    node.target = memory.view
+
+                    # Create spec for the node.
+                    # _MemoryViewSpec is an immutable TensorSpec that gives a view into
+                    # its base spec for non-size related information.
+
+                    # The shape is not the same as node.args[1] because node.args[1]
+                    # can have an inferred size (-1). This is also better if we want to
+                    # extend support of memory.view to dynamic shapes in the future
+                    # (which requires an operator that implements memory.view).
+                    shape = node.meta["val"].shape
+                    node.meta["spec"] = _MemoryViewSpec(base.meta["spec"], shape)
+
+                    n_replaced += 1
+
+            module.recompile()
+
+        logger.debug(f"Replaced {n_replaced} view_copy nodes with memory.view nodes.")
+        return PassResult(graph_module, n_replaced > 0)
+
+    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
+        for module in graph_module.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in module.graph.nodes:
+                assert not self._is_replaceable_view_copy(node)
+                if node.op == "call_function" and node.target == memory.view:
+                    assert isinstance(node.meta["spec"], _MemoryViewSpec)
+
+    def requires(self, graph_module: torch.fx.GraphModule) -> None:
+        """
+        This pass should be called after NormalizeViewCopyBasePass.
+        We check that all view_copy nodes have been normalized.
+        """
+        for module in graph_module.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in module.graph.nodes:
+                if _is_view_copy(node):
+                    base, size = node.args
+                    assert not _is_view_copy(base)
+
+    def set_graph_signature(self, graph_signature: ExportGraphSignature) -> None:
+        self._graph_signature = graph_signature
+
+    def _is_replaceable_view_copy(self, node: torch.fx.Node) -> bool:
+        if not _is_view_copy(node):
+            return False
+
+        base = node.args[0]
+        assert isinstance(base, torch.fx.Node)
+
+        is_base_memory_planned = False  # until proven otherwise
+        if base.op == "call_function":
+            is_base_memory_planned = is_op_memory_planned(base)
+
+        if base.op == "placeholder":
+            # For now, we only assume placeholder parameters + buffers are memory planned.
+            # We cautiously assume that general user inputs are not memory planned.
+            # Memory planning for user inputs can be turned off, see
+            # https://github.com/pytorch/executorch/blob/main/exir/passes/memory_planning_pass.py#L32
+
+            if self._graph_signature is None or not isinstance(
+                self._graph_signature, ExportGraphSignature
+            ):
+                logger.warning(
+                    "The ExportGraphSignature was not set prior to calling ReplaceViewCopyWithMemoryViewPass. Placeholder view_copy nodes cannot be replaced without first calling set_graph_signature."
+                )
+            else:
+                if base.name in self._graph_signature.inputs_to_parameters:
+                    is_base_memory_planned = True
+                elif base.name in self._graph_signature.inputs_to_buffers:
+                    is_base_memory_planned = True
+                else:
+                    pass
+
+        return is_base_memory_planned and node.meta["spec"].is_static_shape_tensor
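The freeze/delegate/guard mechanics of _MemoryViewSpec can be sketched with a minimal pure-Python pair of classes. The names below are hypothetical stand-ins (the real spec also mirrors stride, dim_order, and more); the sketch only shows the three behaviors: size info lives on the view, everything else reads through to the base, and writes to the view are rejected after freezing.

```python
class BaseSpec:
    def __init__(self):
        self.mem_id = None            # mutable planning state on the base
        self.scalar_type = "float32"  # guarded: must not change after view creation

class FrozenViewSpec:
    def __init__(self, base, shape):
        object.__setattr__(self, "_frozen", False)
        self._base = base
        self.shape = shape
        self._guards = {"scalar_type": base.scalar_type}
        object.__setattr__(self, "_frozen", True)

    def __getattr__(self, name):
        # Called only when normal lookup fails: read through to the base,
        # after checking that guarded base attributes have not drifted.
        base = self.__dict__["_base"]
        for key, expected in self.__dict__["_guards"].items():
            if getattr(base, key) != expected:
                raise Exception(f"guarded attribute '{key}' changed on the base")
        return getattr(base, name)

    def __setattr__(self, name, val):
        if self.__dict__.get("_frozen", False):
            raise Exception("immutable: update the base instead")
        object.__setattr__(self, name, val)

base = BaseSpec()
view_spec = FrozenViewSpec(base, shape=[3, 4])

assert view_spec.shape == [3, 4]   # size info lives on the view itself
base.mem_id = 42                   # mutate the base...
assert view_spec.mem_id == 42      # ...and the view reflects it

raised = False
try:
    view_spec.mem_id = 0           # but the view itself rejects writes
except Exception:
    raised = True
assert raised
```

This is why the memory_planning_pass must be given the base's spec when it updates a lifetime: a write through the view spec would raise, while a write to the base is immediately visible through the view.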

exir/tests/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -208,6 +208,7 @@ python_unittest(
         "//executorch/exir:memory",
         "//executorch/exir:memory_planning",
         "//executorch/exir:pass_base",
+        "//executorch/exir:schema",
         "//executorch/exir:tensor",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
@@ -220,6 +221,7 @@ python_unittest(
         "//executorch/exir/passes:remove_graph_asserts_pass",
         "//executorch/exir/passes:remove_mixed_type_operators",
         "//executorch/exir/passes:replace_edge_with_backend_pass",
+        "//executorch/exir/passes:replace_view_copy_with_memory_view_pass",
         "//executorch/exir/passes:scalar_to_tensor_pass",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/passes:sym_to_tensor_pass",
