Commit b77fd57

JacobSzwejbka authored and facebook-github-bot committed

emit programs with mutable buffers (#2233)
Summary: Meaningful changes to the emitter logic here. Previously we ignored the tensor spec passed in, tried to decide whether the placeholder was a constant, and, if it was, created a new spec from that constant's actual value. That drops metadata from the input spec, which is not great. Now we instead look up the storage of the concrete tensor and hook it up to the existing spec. This change also adds logic to separate out the behavior for mutable buffers specifically. While working on this I discovered a bug: memory planning reserves space for parameters and constant buffers when it is told to allocate space for inputs, which is really bad. One big assumption this diff makes is that a mutated buffer does not have a meaningful initial state, so a warning is emitted about this in the short term; long term we will handle initial state properly.

bypass-github-export-checks

Reviewed By: tarun292

Differential Revision: D53713544
1 parent 862f755 commit b77fd57
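For context, a minimal sketch of the kind of stateful module this commit lets the emitter handle. The Counter module here is hypothetical; the export/to_edge/to_executorch calls mirror the new test added below.

    import torch
    from torch.export import export
    from executorch.exir import to_edge

    class Counter(torch.nn.Module):  # hypothetical example module
        def __init__(self) -> None:
            super().__init__()
            # Initial state is assumed meaningless; only shape and dtype are serialized.
            self.register_buffer("state", torch.zeros(1))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x + self.state
            self.state.add_(1)  # in-place buffer mutation, previously rejected at emit time
            return y

    # Before this change, emit_program raised ExportError ("Buffers cannot be
    # modified in executorch."); now it emits the plan and warns instead.
    program = to_edge(export(Counter(), (torch.zeros(1),))).to_executorch()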

File tree

4 files changed: +140 −36 lines changed

exir/emit/_emit_program.py

Lines changed: 38 additions & 9 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+import copy
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
@@ -32,7 +33,8 @@
 )
 from executorch.exir.tensor import layout_enum, scalar_type_enum
 from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.utils import _pytree as pytree
 
 
 def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
@@ -122,6 +124,36 @@ class EmitterOutput:
     ]
 
 
+def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
+    gm = copy.deepcopy(exported_program.graph_module)
+    output_node = None
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            output_node = node
+    assert output_node is not None
+
+    mutated_outputs: List[Optional[str]] = [
+        out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
+        for out_spec in exported_program.graph_signature.output_specs
+    ]
+    outputs = pytree.tree_flatten(output_node.args)[0]
+
+    user_output_nodes = []
+    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
+        if mutated_node_name is None:
+            user_output_nodes.append(return_node)
+            continue
+
+    with gm.graph.inserting_before(output_node):
+        # Only return user outputs
+        new_output = gm.graph.output(tuple(user_output_nodes))
+        new_output.meta = output_node.meta.copy()
+        output_node.replace_all_uses_with(new_output)
+        gm.graph.erase_node(output_node)
+
+    return gm
+
+
 def emit_program(
     methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
     emit_stacktrace: bool = False,
@@ -163,13 +195,6 @@ def emit_program(
 
     # emit each entry point in order according to name.
     for name, exported_program in sorted(methods.items()):
-        if (
-            exported_program.graph_signature.buffers_to_mutate
-        ):  # see if we are mutating any state
-            raise ExportError(
-                ExportErrorType.INVALID_INPUT_TYPE,
-                "Buffers cannot be modified in executorch.",
-            )
         # create empty state
         emitter_state = _EmitterState(
             values=[],
@@ -180,7 +205,11 @@ def emit_program(
             emit_stacktrace=emit_stacktrace,
         )
 
-        emitter = _TopLevelEmitter(name, exported_program, program_state, emitter_state)
+        gm = _remove_non_user_outputs(exported_program)
+
+        emitter = _TopLevelEmitter(
+            name, exported_program, gm, program_state, emitter_state
+        )
 
         emitter.run()
         plans.append(emitter.plan())
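A note on `_remove_non_user_outputs`: `torch.export` records buffer mutations as extra entries in the graph's output node, tagged in `graph_signature.output_specs`, positionally aligned with the graph outputs. The helper rebuilds the output node with only the user outputs, so those mutation writes do not leak into the emitted method's return signature. A small sketch of what the signature looks like for the hypothetical Counter module above:

    from torch.export import export
    from torch.export.exported_program import OutputKind

    ep = export(Counter(), (torch.zeros(1),))
    for out_spec in ep.graph_signature.output_specs:
        # Expect something like:
        #   OutputKind.BUFFER_MUTATION, target='state'  (filtered out by the helper)
        #   OutputKind.USER_OUTPUT, target=None         (kept)
        print(out_spec.kind, out_spec.target)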

exir/emit/_emitter.py

Lines changed: 60 additions & 25 deletions
@@ -30,6 +30,7 @@
 # pyre-strict
 import ctypes
 import hashlib
+import warnings
 import operator
 import typing
 from dataclasses import dataclass, field
@@ -1266,15 +1267,17 @@ def __init__(
         self,
         name: str,
         exported_program: ExportedProgram,
+        graph_module: torch.fx.GraphModule,
         program_state: _ProgramState,
         emitter_state: _EmitterState,
     ) -> None:
-        super().__init__(exported_program.graph_module, emitter_state, program_state)
+        super().__init__(graph_module, emitter_state, program_state)
         self.name = name
         self.exported_program = exported_program
 
         self.inputs: List[int] = []
         self.outputs: List[int] = []
+        self.given_mutable_buffer_warning = False
 
     def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
         if spec is None:
@@ -1302,40 +1305,57 @@ def placeholder(
         https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
         """
         spec = self.node.meta["spec"]
-        const_tensor = False
-        if isinstance(target, str) and (
-            target in self.exported_program.graph_signature.inputs_to_parameters
-            or target in self.exported_program.graph_signature.inputs_to_buffers
-            or target
-            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-        ):
-            if (
+        is_user_input = True
+
+        if isinstance(target, str) and isinstance(spec, TensorSpec):
+            # Find the fully qualified name
+            fqn = None
+            is_mutable_buffer = False
+            if target in self.exported_program.graph_signature.inputs_to_parameters:
+                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+            elif target in self.exported_program.graph_signature.inputs_to_buffers:
+                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
+
+                # if the buffer is mutated then record that
+                if (
+                    fqn
+                    in self.exported_program.graph_signature.buffers_to_mutate.values()
+                ):
+                    is_mutable_buffer = True
+                    if not self.given_mutable_buffer_warning:
+                        warnings.warn(
+                            "Mutation on a buffer in the model is detected. ExecuTorch assumes "
+                            "buffers that are mutated in the graph have a meaningless initial state, "
+                            "only the shape and dtype will be serialized.",
+                            UserWarning,
+                            stacklevel=1,
+                        )
+                        self.mutable_buffer_warning_count = True
+
+            elif (
                 target
                 in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                     target
                 ]
-            elif target in self.exported_program.graph_signature.inputs_to_buffers:
-                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
-            else:
-                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+            # From the fqn find the corresponding tensor
+            real_tensor = None
             if fqn in self.exported_program.state_dict:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.state_dict[fqn], const=True
-                )
-                const_tensor = True
+                real_tensor = self.exported_program.state_dict[fqn]
+                is_user_input = False
+
             elif fqn in self.exported_program.constants:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.constants[fqn], const=True
-                )
-                const_tensor = True
-            else:
+                real_tensor = self.exported_program.constants[fqn]
+                is_user_input = False
+            elif fqn is not None:
                 buffers = self.exported_program.named_buffers()
                 buf = next((x[1] for x in buffers if x[0] == fqn), None)
                 if buf is not None:
-                    spec = TensorSpec.from_tensor(buf, const=True)
-                    const_tensor = True
+                    real_tensor = buf
+                    is_user_input = False
                 else:
                     raise InternalError(
                         self._emit_node_specific_error(
@@ -1344,13 +1364,28 @@ def placeholder(
                         )
                     )
 
+            # assign the storage of the placeholder spec to the storage of the real tensor if there is one
+            if real_tensor is not None:
+                # for non-contiguous tensors, convert to a contiguous one
+                real_tensor = real_tensor.contiguous()
+                # Weights cannot be views during emission or serialization
+                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
+                    real_tensor = real_tensor.clone()
+
+                spec.storage = real_tensor.untyped_storage()
+
+            # User inputs and mutable buffers are not constants, other buffers or parameters are.
+            spec.const = not (is_user_input or is_mutable_buffer)
+
         evalue = (
            self._tensor_spec_to_evalue(spec)
            if isinstance(spec, TensorSpec)
            else self._constant_to_evalue(spec, None)
        )
         value = self._emit_evalue(evalue)
-        if not const_tensor:
+
+        # Only user inputs should remain as inputs.
+        if is_user_input:
             self.inputs.append(value.id)
 
         return value
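The view guard in the new storage hookup is worth unpacking: a tensor that is a view shares a larger untyped storage, so serializing its storage directly would write out bytes that do not belong to the weight. A standalone sketch of the condition the emitter checks:

    import torch

    base = torch.arange(8, dtype=torch.float32)
    view = base[:2]  # a view over the first 2 of 8 elements
    print(view.nbytes)                      # 8: bytes owned by the view
    print(view.untyped_storage().nbytes())  # 32: bytes of the whole underlying storage
    # nbytes != untyped_storage().nbytes() flags the view, so the emitter clones it
    # and serializes exactly the bytes that belong to the weight.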

exir/emit/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,6 @@ python_unittest(
         "//executorch/exir/passes:constant_prop_pass",
         "//executorch/exir/tests:lib",
         "//executorch/exir/tests:models",
-        "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/extension/pybindings:aten_lib",
     ],
 )

exir/emit/test/test_emit.py

Lines changed: 41 additions & 1 deletion
@@ -19,6 +19,7 @@
 from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager, to_edge
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -42,6 +43,7 @@
 )
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.models import Mul
+from executorch.extension.pybindings.aten_lib import _load_for_executorch_from_buffer
 from functorch.experimental import control_flow
 from torch import nn
 
@@ -1028,7 +1030,7 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         edge = to_edge(captured)
         from executorch.exir.passes import MemoryPlanningPass
 
-        config = exir.ExecutorchBackendConfig(
+        config = exir.ExecutorchBackendConfig(  # pyre-ignore[28]
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             memory_planning_pass=MemoryPlanningPass(
                 memory_planning_algo="greedy",
@@ -1393,3 +1395,41 @@ def forward(self, x):
         self.assertEqual(len(exec_plan.inputs), 1)
         self.assertEqual(len(program.constant_buffer), 2)
         self.assertEqual(len(program.constant_buffer[1].storage), 24)
+
+    def test_mutable_buffers(self) -> None:
+        def count_copies(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target == torch.ops.aten.copy_
+                    or node.target == exir_ops.edge.aten.copy_.default
+                )
+                for node in gm.graph.nodes
+            )
+
+        class MutableStateModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x):
+                y = x + self.state
+                self.state.add_(1)
+                return y
+
+        model = to_edge(
+            export(
+                MutableStateModule(),
+                (torch.zeros(1),),
+            )
+        )
+        model = model.to_executorch()
+        model.dump_executorch_program(True)
+        self.assertTrue(
+            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
+            .values[0]
+            .val.allocation_info
+            is not None
+        )
+        executorch_module = _load_for_executorch_from_buffer(model.buffer)
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
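For reference, the two `executorch_module` assertions mirror what `MutableStateModule` does in eager mode: the first call returns 0 and the second, after the in-place increment has persisted, returns 1. (The first assertion also relies on the planned buffer coming up zeroed at runtime, consistent with the "meaningless initial state" assumption in the summary.) A quick eager-mode sketch:

    m = MutableStateModule()
    print(m(torch.zeros(1)))  # tensor([0.]) -- state starts at zero
    print(m(torch.zeros(1)))  # tensor([1.]) -- the mutation persisted across calls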
