Skip to content

Commit a704dd6

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Inject Inplace Copies into Graph for Mutable Buffers (#1995)
Summary: Pull Request resolved: #1995 Injects copy nodes into the graph Reviewed By: larryliu0820 Differential Revision: D53713415 fbshipit-source-id: 78381f9df5356a50c126ad8eee7955e2d8e0be10
1 parent b321616 commit a704dd6

File tree

7 files changed

+193
-0
lines changed

7 files changed

+193
-0
lines changed

exir/passes/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python_library(
1010
deps = [
1111
":const_prop_pass",
1212
":debug_handle_generator_pass",
13+
":insert_write_back_for_buffers_pass",
1314
":memory_format_ops_pass",
1415
":memory_planning_pass",
1516
":normalize_transpose_pass",
@@ -51,6 +52,16 @@ python_library(
5152
],
5253
)
5354

55+
python_library(
56+
name = "insert_write_back_for_buffers_pass",
57+
srcs = [
58+
"insert_write_back_for_buffers_pass.py",
59+
],
60+
deps = [
61+
"//caffe2:torch",
62+
],
63+
)
64+
5465
python_library(
5566
name = "const_prop_pass",
5667
srcs = [

exir/passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
3737

3838
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
39+
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
40+
insert_write_back_for_buffers_pass,
41+
)
3942
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
4043
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4144
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
@@ -65,6 +68,7 @@
6568
"MemoryFormatOpsPass",
6669
"MemoryPlanningPass",
6770
"HintBasedSymShapeEvalPass",
71+
"insert_write_back_for_buffers_pass",
6872
]
6973

7074
Argument = Optional[
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, List, Optional
8+
9+
import torch
10+
11+
from torch.export.exported_program import (
12+
ExportedProgram,
13+
ExportGraphSignature,
14+
InputKind,
15+
OutputKind,
16+
OutputSpec,
17+
)
18+
from torch.utils import _pytree as pytree
19+
20+
21+
def _insert_copy(
    gm: torch.fx.GraphModule,
    mutated_outputs: List[Optional[str]],
    input_name_to_node: Dict[str, torch.fx.Node],
):
    """
    Insert ``aten.copy_`` nodes that write mutated outputs back to their inputs.

    ``mutated_outputs`` is matched position-by-position against the flattened
    arguments of the graph's output node. A ``None`` entry marks a plain user
    output that is passed through untouched; a string entry names the key in
    ``input_name_to_node`` whose node is the destination of the copy, with the
    corresponding returned node as the source.

    The existing output node is replaced by a new one whose values are all the
    inserted copy nodes followed by all the untouched user outputs.

    Returns:
        The list of nodes handed to the new output node (copy nodes first,
        then user outputs).

    Raises:
        RuntimeError: if a non-``None`` name is missing from
            ``input_name_to_node``.
    """
    graph = gm.graph

    # Locate the first (normally the only) output node of the graph.
    output_node = next((n for n in graph.nodes if n.op == "output"), None)
    assert output_node is not None

    flat_outputs, _ = pytree.tree_flatten(output_node.args)
    assert len(flat_outputs) == len(mutated_outputs)

    copy_nodes = []
    passthrough_nodes = []
    for return_node, mutated_node_name in zip(flat_outputs, mutated_outputs):
        if mutated_node_name is None:
            # Ordinary user output: nothing to write back.
            passthrough_nodes.append(return_node)
            continue

        if mutated_node_name not in input_name_to_node:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )
        destination = input_name_to_node[mutated_node_name]

        # copy_(destination, source) placed right before the output node; the
        # copy node itself becomes the value that is returned.
        with graph.inserting_before(output_node):
            copy_nodes.append(
                graph.call_function(
                    torch.ops.aten.copy_.default, (destination, return_node)
                )
            )

    new_outputs = copy_nodes + passthrough_nodes
    with graph.inserting_before(output_node):
        new_output = graph.output(tuple(new_outputs))

    # Swap the old output node for the freshly built one.
    output_node.replace_all_uses_with(new_output)
    graph.erase_node(output_node)
    return new_outputs
70+
71+
72+
def insert_write_back_for_buffers_pass(ep: ExportedProgram):
    """
    Add ``copy_`` write-backs for every buffer mutation of an exported program.

    For each output spec of kind ``BUFFER_MUTATION`` a ``copy_`` node is
    inserted into ``ep.graph_module`` (see ``_insert_copy``) and the spec's
    ``arg.name`` is re-pointed at that copy node. The graph module is edited
    in place, and the output spec objects held by ``ep.graph_signature`` are
    patched in place as well.

    Returns:
        A ``(graph_module, signature)`` tuple: the modified graph module and a
        new ``ExportGraphSignature`` built from the original input specs and
        the patched output specs.
    """
    gm: torch.fx.GraphModule = ep.graph_module
    input_specs = ep.graph_signature.input_specs
    output_specs = ep.graph_signature.output_specs

    # Input kinds whose spec target is used as the lookup key for the
    # matching placeholder; every other input kind maps to None.
    lifted_kinds = (
        InputKind.BUFFER,
        InputKind.CONSTANT_TENSOR,
        InputKind.PARAMETER,
        InputKind.CUSTOM_OBJ,
    )
    lifted_inputs: List[Optional[str]] = [
        spec.target if spec.kind in lifted_kinds else None for spec in input_specs
    ]

    # Per-output: the mutated buffer's target, or None for any other output.
    mutated_outputs: List[Optional[str]] = [
        spec.target if spec.kind == OutputKind.BUFFER_MUTATION else None
        for spec in output_specs
    ]

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)

    # Placeholders are paired with input specs by position.
    input_name_to_node: Dict[str, torch.fx.Node] = {
        target: placeholder
        for placeholder, target in zip(placeholder_nodes, lifted_inputs)
        if target is not None
    }

    # Insert the copy ops and rebuild the graph outputs.
    buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node)
    gm.graph.lint()
    gm.graph.eliminate_dead_code()
    gm.recompile()

    # Re-point each buffer-mutation spec at its copy node: the k-th mutation
    # spec takes the name of the k-th node returned by _insert_copy.
    mutation_specs = [
        spec for spec in output_specs if spec.kind == OutputKind.BUFFER_MUTATION
    ]
    for spec, copy_node in zip(mutation_specs, buffer_output_nodes):
        spec.arg.name = copy_node.name
    new_output_specs: List[OutputSpec] = list(output_specs)

    signature = ExportGraphSignature(
        input_specs=input_specs,
        output_specs=new_output_specs,
    )

    return gm, signature

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ python_library(
3030
"//executorch/exir/capture:config",
3131
"//executorch/exir/emit:emit",
3232
"//executorch/exir/emit:lib",
33+
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
3334
"//executorch/exir/passes:lib",
3435
"//executorch/exir/passes:remove_graph_asserts_pass",
3536
"//executorch/exir/passes:remove_mixed_type_operators",

exir/program/_program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
MemoryFormatOpsPass,
2727
OpReplacePass,
2828
)
29+
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
30+
insert_write_back_for_buffers_pass,
31+
)
2932
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
3033
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
3134
from executorch.exir.passes.spec_prop_pass import SpecPropPass
@@ -45,6 +48,7 @@
4548
ExportGraphSignature,
4649
InputKind,
4750
InputSpec,
51+
OutputKind,
4852
OutputSpec,
4953
TensorArgument,
5054
)
@@ -1034,6 +1038,7 @@ def to_executorch(
10341038

10351039
execution_programs: Dict[str, ExportedProgram] = {}
10361040
for name, program in self._edge_programs.items():
1041+
gm, _ = insert_write_back_for_buffers_pass(program)
10371042
new_gm = program.graph_module
10381043
for p in edge_to_executorch_passes(config):
10391044
new_gm_res = p(new_gm)

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ python_unittest(
213213
"//executorch/exir/emit:lib",
214214
"//executorch/exir/passes:constant_prop_pass",
215215
"//executorch/exir/passes:debug_handle_generator_pass",
216+
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
216217
"//executorch/exir/passes:lib",
217218
"//executorch/exir/passes:remove_graph_asserts_pass",
218219
"//executorch/exir/passes:remove_mixed_type_operators",

exir/tests/test_passes.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
)
3434
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
3535
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
36+
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
37+
insert_write_back_for_buffers_pass,
38+
)
3639
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
3740
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
3841
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
@@ -1195,3 +1198,49 @@ def forward(self, pred, x):
11951198
error_msg,
11961199
):
11971200
_ = constant_prop_pass(edge.exported_program())
1201+
1202+
def test_mutable_buffers(self) -> None:
    """Check that the pass adds exactly one copy_ for a module with one mutated buffer."""

    def count_copies(gm: torch.fx.GraphModule) -> int:
        # Number of nodes in the graph whose target is aten.copy_.default.
        copy_op = torch.ops.aten.copy_.default
        return sum(1 for node in gm.graph.nodes if node.target == copy_op)

    class MutableStateModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("state", torch.zeros(1))

        def forward(self, x):
            y = x + self.state
            self.state.add_(1)
            return y

    exported = export(
        MutableStateModule(),
        (torch.zeros(1),),
    )
    model = to_edge(exported)

    # No copy_ nodes are expected before the pass runs.
    self.assertEqual(count_copies(model.exported_program().graph_module), 0)
    # Before
    # graph():
    #     %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
    #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
    #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
    #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
    #     return (aten_add_tensor_1, aten_add_tensor)
    gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

    # After
    # graph():
    #     %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
    #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
    #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
    #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
    #     %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
    #     return (copy__default, aten_add_tensor)
    self.assertEqual(count_copies(gm), 1)

0 commit comments

Comments
 (0)