
Commit 8dcea23

metascroy authored and facebook-github-bot committed
Replace view copy with view (3/3)
Summary:
Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib

This stack replaces view_copy nodes with memory.view nodes.

In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means that if we have something like op -> view_copy1 -> view_copy2, then after normalization both view copies will point to op in their base (assuming op is not a view node). Note that this pass, combined with dead-code elimination, removes redundant view copies, because a redundant view copy will have no users after this pass.

In the second diff (D54827305), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size-related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, the update is reflected in the memory.view node's _MemoryViewSpec. (An illustrative sketch of this lookup behavior follows the commit metadata below.)

Not all view_copy nodes are converted to memory.view nodes; only static nodes that are memory planned are converted, and not all static nodes are memory planned in ExecuTorch. For example, there is an option to turn off memory planning for input nodes, and outputs from some higher-order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time.

In the third diff (D54827438), I implement the actual view_copy elimination. The ExecutorchBackendConfig gains a new option, try_remove_view_copy. If try_remove_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if try_remove_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (the state today).

Let's look at the flow when try_remove_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are just the first and second diffs described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update the lifetime), we return the spec of its base. Returning the spec of the base means that whenever we see a memory.view node, we actually update the lifetime of the base to cover it, and the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec of the memory.view node itself, because the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue.

There are two more diffs on the stack, D54866523 and D54866539. The first replaces the old RemoveRedundantViewCopy pass with NormalizeViewCopyBasePass plus dead-code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when it is safe to do so, to take advantage of the view_copy elimination.
Differential Revision: D54827438
1 parent 1d6105d commit 8dcea23
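
To make the _MemoryViewSpec behavior described in the summary concrete, here is a minimal illustrative sketch (not the actual ExecuTorch class; the name _ToyMemoryViewSpec and its fields are invented for illustration) of an immutable spec that reads non-size fields from its base spec at access time:

class _ToyMemoryViewSpec:
    # Non-size fields are always read from the base spec at access time.
    _MIRRORED_FIELDS = ("const", "storage", "mem_id", "mem_offset", "lifetime")

    def __init__(self, base_spec, shape):
        object.__setattr__(self, "_base", base_spec)
        object.__setattr__(self, "shape", shape)  # size-related field owned by the view

    def __getattr__(self, name):
        # Because the lookup happens on every access, a later update to the base
        # (e.g. a lifetime extended by memory planning) is reflected here.
        if name in _ToyMemoryViewSpec._MIRRORED_FIELDS:
            return getattr(self._base, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Immutable: memory planning must update the base spec, never the view spec.
        raise AttributeError("view specs are immutable; update the base spec instead")

With such a spec, extending base_spec.lifetime is immediately visible when reading lifetime on the view spec, which is the mechanism the memory planner relies on when it extends the base's lifetime to cover a memory.view node.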

File tree

9 files changed, +292 -3 lines changed


exir/capture/_config.py

Lines changed: 5 additions & 0 deletions
@@ -75,3 +75,8 @@ class ExecutorchBackendConfig:
     # be a power of 2. If not provided, uses the value in the schema file.
     delegate_alignment: Optional[int] = None
     sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()
+
+    # If set to true, view_copy operations will be removed from the graph when safe
+    # Rather than be emitted as operators, they will be emitted as evalues that share
+    # the same underlying storage as their base
+    try_remove_view_copy: bool = True
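
For reference, this flag is toggled through ExecutorchBackendConfig when lowering to ExecuTorch. A minimal usage sketch, mirroring the test added below (the exported program ep is a placeholder):

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

# `ep` is assumed to be a program produced by torch.export.export(...)
etpm = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(
        try_remove_view_copy=True,  # default; set to False to keep view_copy operators
        memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
    ),
)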

exir/emit/_emitter.py

Lines changed: 8 additions & 0 deletions
@@ -1198,6 +1198,14 @@ def call_function(
             assert len(args) == 1
             return self._emit_spec(self.node.meta["spec"])

+        elif target == memory.view:
+            assert len(args) == 2
+
+            # A memory.view's base should already be emitted, so the
+            # memory.view's spec should dynamically reference its base's
+            # final state
+            return self._emit_spec(self.node.meta["spec"])
+
         elif target == memory.free:
             assert len(args) == 1
             # pyre-ignore

exir/memory_planning.py

Lines changed: 20 additions & 1 deletion
@@ -261,6 +261,19 @@ def verify_graph_input_output(self) -> None:
             graph_output_allocated == self.alloc_graph_output
         ), f"Misallocate graph output {graph_output_allocated} v.s. {self.alloc_graph_output}"

+    def verify_memory_view_are_memory_planned(self) -> None:
+        """
+        memory.view nodes should only exist if their base is memory planned.
+        """
+        for node in self.graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == memory.view:
+                assert (
+                    node.meta["spec"].const or node.meta["spec"].mem_id is not None
+                ), "memory.view node is not const and has no mem_id."
+                assert (
+                    node.meta["spec"].const or node.meta["spec"].mem_offset is not None
+                ), "memory.view node is not const has no mem_offset."
+

 def register_algo(fn: Callable[..., List[int]]) -> Callable[..., List[int]]:
     algo_name = fn.__name__
@@ -535,7 +548,13 @@ def get_node_tensor_specs(
     has no tensor specs.
     """
     # get tensor specs
-    specs = node.meta.get("spec")
+    if node.target == memory.view:
+        base = node.args[0]
+        assert isinstance(base, torch.fx.Node)
+        specs = base.meta.get("spec")
+    else:
+        specs = node.meta.get("spec")
+
     if isinstance(specs, TensorSpec):
         specs = [specs]
     if not isinstance(specs, (list, tuple)):
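
As a hedged illustration of this change's effect (memory_view_node is a hypothetical memory.view node, not a name from the diff): asking for the tensor specs of a memory.view node hands back its base's spec object, so any lifetime update the planner makes lands on the base and is mirrored by the view's immutable _MemoryViewSpec.

# Illustration only
specs = get_node_tensor_specs(memory_view_node)
base = memory_view_node.args[0]
assert specs[0] is base.meta["spec"]  # the base's spec is returned, not a separate view spec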

exir/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -251,6 +251,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
    # we won't see it in the input graph to the to_out_variant pass, unless
    # it's retraced after running to_out_variant with the first trace.
    memory.alloc,
+   memory.view,
    executorch_call_delegate,
    torch.ops.aten.copy_.default,
 }

exir/passes/memory_planning_pass.py

Lines changed: 1 addition & 0 deletions
@@ -128,4 +128,5 @@ def run(
             f"The {self.memory_planning_algo} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
         )
         verifier.verify_graph_input_output()
+        verifier.verify_memory_view_are_memory_planned()
         return PassResult(graph_module, True)

exir/program/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -32,8 +32,10 @@ python_library(
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:insert_write_back_for_buffers_pass",
         "//executorch/exir/passes:lib",
+        "//executorch/exir/passes:normalize_view_copy_base_pass",
         "//executorch/exir/passes:remove_graph_asserts_pass",
         "//executorch/exir/passes:remove_mixed_type_operators",
+        "//executorch/exir/passes:replace_view_copy_with_memory_view_pass",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/verification:verifier",
     ],

exir/program/_program.py

Lines changed: 31 additions & 2 deletions
@@ -31,8 +31,14 @@
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )
+from executorch.exir.passes.normalize_view_copy_base_pass import (
+    NormalizeViewCopyBasePass,
+)
 from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
+from executorch.exir.passes.replace_view_copy_with_memory_view_pass import (
+    ReplaceViewCopyWithMemoryViewPass,
+)
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.print_program import pretty_print, print_program
 from executorch.exir.schema import Program
@@ -580,6 +586,23 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
     return new_ep


+def memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]:
+    if config.try_remove_view_copy:
+        # pyre-ignore
+        return [
+            NormalizeViewCopyBasePass(),
+            ReplaceViewCopyWithMemoryViewPass(),
+            config.to_out_var_pass,
+            config.memory_planning_pass,
+        ]
+    else:
+        # pyre-ignore
+        return [
+            config.to_out_var_pass,
+            config.memory_planning_pass,
+        ]
+
+
 def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
     # pyre-ignore
     passes: List[PassType] = [
@@ -591,8 +614,8 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
         EdgeToBackendOpsPass(),
         RemoveGraphAssertsPass(),
         config.sym_shape_eval_pass,
-        config.to_out_var_pass,
-    ]
+    ] + memory_planning_passes(config)
+
     return passes

@@ -835,6 +858,12 @@ def to_executorch(
             gm, new_signature = insert_write_back_for_buffers_pass(program)
             new_gm = program.graph_module
             for p in edge_to_executorch_passes(config):
+                if isinstance(p, ReplaceViewCopyWithMemoryViewPass):
+                    # This is similar to the hack in SpecPropPass
+                    # Ideally passes would work on ExportedPrograms, but today
+                    # they work on GraphModules
+                    p.set_program(program)
+
                 new_gm_res = p(new_gm)
                 assert new_gm_res is not None
                 new_gm = new_gm_res.graph_module

exir/tests/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -447,3 +447,17 @@ python_unittest(
         "//executorch/exir:print_program",
     ],
 )
+
+python_unittest(
+    name = "test_try_remove_view_copy",
+    srcs = [
+        "test_try_remove_view_copy.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory",
+        "//executorch/exir/capture:config",
+        "//executorch/exir/passes:lib",
+    ],
+)
exir/tests/test_try_remove_view_copy.py (new file)

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
import torch.nn as nn
from executorch.exir import memory, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass


class TestModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.parameter = nn.Parameter(torch.rand(5, 6))
        self.parameter.requires_grad = False

    def forward(self, x):
        v1 = self.parameter.view(
            6, 5
        )  # removed, lifetime of parameter will be extended
        v2 = x.view(6, 5)  # not removed
        v3 = torch.ops.aten.mul.Tensor(v1, v2).view(
            30
        )  # removed, lifetime of mul.Tensor will be extended
        return v3

    def get_example_inputs(self):
        return (torch.rand(5, 6),)


class TestTryRemoveViewCopy(unittest.TestCase):
    def test_disable(self) -> None:
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)
        etpm = to_edge(ep).to_executorch(
            config=ExecutorchBackendConfig(
                try_remove_view_copy=False,
                memory_planning_pass=MemoryPlanningPass(
                    "greedy", alloc_graph_input=False
                ),
            ),
        )

        for node in etpm.exported_program().graph_module.graph.nodes:
            assert node.target != memory.view

    def test_output_matches(self) -> None:
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        epm_remove = to_edge(ep)
        epm_no_remove = copy.deepcopy(
            epm_remove
        )  # to_executorch modifies the edge_program, so we make a copy

        # Run pass with removal
        etpm_remove = epm_remove.to_executorch(
            config=ExecutorchBackendConfig(
                try_remove_view_copy=True,
                memory_planning_pass=MemoryPlanningPass(
                    "greedy", alloc_graph_input=False
                ),
            ),
        )

        # Run pass with no removal
        etpm_no_remove = epm_no_remove.to_executorch(
            config=ExecutorchBackendConfig(
                try_remove_view_copy=False,
                memory_planning_pass=MemoryPlanningPass(
                    "greedy", alloc_graph_input=False
                ),
            ),
        )

        out_remove = etpm_remove.exported_program().module()(*example_inputs)
        out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs)

        self.assertTrue(torch.allclose(out_remove, out_no_remove))

    def test_spec(self) -> None:
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        etpm = to_edge(ep).to_executorch(
            config=ExecutorchBackendConfig(
                try_remove_view_copy=True,
                memory_planning_pass=MemoryPlanningPass(
                    "greedy", alloc_graph_input=False
                ),
            ),
        )

        # etpm.exported_program().graph.print_tabular()

        # idx opcode         name                      target                              args                                                 kwargs
        # --- -------------  ------------------------  ----------------------------------  ---------------------------------------------------  ----------------
        # 0   placeholder    arg0_1                    arg0_1                              ()                                                   {}
        # 1   placeholder    arg1_1                    arg1_1                              ()                                                   {}
        # 2   call_function  aten_view_copy_default    <function view at 0x7f10a6dfeb00>   (arg0_1, [6, 5])                                     {}
        # 3   call_function  alloc                     <function alloc at 0x7f10a6dfe9e0>  (((6, 5), torch.float32),)                           {}
        # 4   call_function  aten_view_copy_default_1  aten.view_copy.out                  (arg1_1, [6, 5])                                     {'out': alloc}
        # 5   call_function  alloc_1                   <function alloc at 0x7f10a6dfe9e0>  (((6, 5), torch.float32),)                           {}
        # 6   call_function  aten_mul_tensor           aten.mul.out                        (aten_view_copy_default, aten_view_copy_default_1)   {'out': alloc_1}
        # 7   call_function  aten_view_copy_default_2  <function view at 0x7f10a6dfeb00>   (aten_mul_tensor, [30])                              {}
        # 8   output         output_1                  output                              ((aten_view_copy_default_2,),)                       {}

        # arg0_1 is the parameter
        # arg1_1 is the user input

        for node in etpm.exported_program().graph.nodes:
            if node.name == "arg0_1":
                # arg0_1's lifetime is extended through aten_view_copy_default (memory.view) to idx 6
                self.assertEqual(node.meta["spec"].lifetime, [0, 6])
            elif node.name == "aten_view_copy_default":
                # aten_view_copy_default is a memory.view of arg0_1.
                # arg0_1 is a constant with storage, so we check that the view's storage matches the base

                # assert base is arg0_1
                self.assertEqual(node.args[0].name, "arg0_1")

                # assert base is const with storage
                self.assertTrue(node.args[0].meta["spec"].const)
                self.assertTrue(node.args[0].meta["spec"].storage is not None)
                self.assertTrue(node.args[0].meta["spec"].mem_id is None)
                self.assertTrue(node.args[0].meta["spec"].mem_offset is None)

                # assert self is const with storage
                self.assertTrue(node.meta["spec"].const)
                self.assertTrue(node.meta["spec"].storage is not None)
                self.assertTrue(node.meta["spec"].mem_id is None)
                self.assertTrue(node.meta["spec"].mem_offset is None)

                # assert storage matches
                self.assertEqual(
                    node.meta["spec"].storage, node.args[0].meta["spec"].storage
                )

                # assert lifetime matches
                self.assertEqual(
                    node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
                )
            elif node.name == "aten_mul_tensor":
                # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 8
                self.assertEqual(node.meta["spec"].lifetime, [5, 8])
            elif node.name == "aten_view_copy_default_2":
                # aten_view_copy_default_2 is a memory.view of aten_mul_tensor

                # assert base is aten_mul_tensor
                self.assertEqual(node.args[0].name, "aten_mul_tensor")

                # assert base and self are not const, do not have storage,
                # but do have mem_id and mem_offset
                self.assertFalse(node.args[0].meta["spec"].const)
                self.assertTrue(node.args[0].meta["spec"].storage is None)
                self.assertTrue(node.args[0].meta["spec"].mem_id is not None)
                self.assertTrue(node.args[0].meta["spec"].mem_offset is not None)

                self.assertFalse(node.meta["spec"].const)
                self.assertTrue(node.meta["spec"].storage is None)
                self.assertTrue(node.meta["spec"].mem_id is not None)
                self.assertTrue(node.meta["spec"].mem_offset is not None)

                # assert self and base mem_id, mem_offset, and lifetime match
                self.assertEqual(
                    node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id
                )
                self.assertEqual(
                    node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset
                )
                self.assertEqual(
                    node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime
                )

        # Test evalues in execution plan
        evalues = etpm.executorch_program.execution_plan[0].values

        # evalue 0 is the parameter arg0_1 and evalue 2 is the view aten_view_copy_default
        # assert their sizes are as expected and constant_buffer_idx != 0
        self.assertEqual(evalues[0].val.sizes, [5, 6])  # pyre-ignore
        self.assertNotEqual(evalues[0].val.constant_buffer_idx, 0)  # pyre-ignore
        self.assertEqual(evalues[2].val.sizes, [6, 5])  # pyre-ignore
        self.assertNotEqual(evalues[2].val.constant_buffer_idx, 0)  # pyre-ignore

        # assert they have the same constant_buffer_idx
        self.assertEqual(evalues[0].val.constant_buffer_idx, evalues[2].val.constant_buffer_idx)  # pyre-ignore

        # evalue 7 is alloc_1 (aten_mul_tensor) and evalue 8 is aten_view_copy_default_2
        # assert their sizes are as expected and constant_buffer_idx == 0
        self.assertEqual(evalues[7].val.sizes, [6, 5])  # pyre-ignore
        self.assertEqual(evalues[7].val.constant_buffer_idx, 0)  # pyre-ignore
        self.assertEqual(evalues[8].val.sizes, [30])  # pyre-ignore
        self.assertEqual(evalues[8].val.constant_buffer_idx, 0)  # pyre-ignore

        # assert they have the same mem_id and mem_offset low and high
        self.assertEqual(evalues[7].val.allocation_info.memory_id, evalues[8].val.allocation_info.memory_id)  # pyre-ignore
        self.assertEqual(evalues[7].val.allocation_info.memory_offset_low, evalues[8].val.allocation_info.memory_offset_low)  # pyre-ignore
        self.assertEqual(evalues[7].val.allocation_info.memory_offset_high, evalues[8].val.allocation_info.memory_offset_high)  # pyre-ignore
