Commit aef3a7c

JacobSzwejbka authored and facebook-github-bot committed
emit programs with mutable buffers (#2233)
Summary: Pull Request resolved: #2233

Meaningful changes to the emitter logic here. Previously we would ignore the tensor spec passed in and, if we decided the placeholder was a constant, create a new spec from the actual value of that constant. That drops metadata on the input spec, which is not great. Now we instead look up the storage of the concrete tensor and hook it up to the spec. Also added some logic to separate out behavior for mutable buffers specifically.

While working on this I also discovered a bug: memory planning plans space for parameters and constant buffers if it's told to allocate space for inputs, which is really bad.

One big assumption this diff makes is that the buffer does not have a meaningful initial state. I should probably throw a warning during emission about this in the short term; long term we will handle initial state properly.

bypass-github-export-checks

Reviewed By: tarun292, Jack-Khuu

Differential Revision: D53713544

fbshipit-source-id: b0bd8abd89e2d0e2006f0f1d885b1eaa02653afa
1 parent c119247 commit aef3a7c
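
Before the file-by-file diff, here is a minimal sketch of the pattern this commit enables, mirroring the test_mutable_buffers test added in exir/emit/test/test_emit.py below; the module name and the exact lowering calls are illustrative only, not part of the diff:

import torch
from torch.export import export
from executorch.exir import to_edge


class MutableStateModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer that forward() mutates in place.
        self.register_buffer("state", torch.zeros(1))

    def forward(self, x):
        y = x + self.state
        self.state.add_(1)
        return y


# Before this change, emit_program raised ExportError whenever
# graph_signature.buffers_to_mutate was non-empty. With it, the mutated buffer
# is emitted as a non-constant, memory-planned tensor (only shape and dtype are
# serialized) and a UserWarning notes that its initial state is ignored.
program = to_edge(export(MutableStateModule(), (torch.zeros(1),))).to_executorch()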

File tree

exir/emit/_emit_program.py
exir/emit/_emitter.py
exir/emit/test/TARGETS
exir/emit/test/test_emit.py

4 files changed, +151 -42 lines


exir/emit/_emit_program.py

Lines changed: 37 additions & 9 deletions
@@ -32,7 +32,8 @@
 )
 from executorch.exir.tensor import layout_enum, scalar_type_enum
 from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.utils import _pytree as pytree


 def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
@@ -122,6 +123,36 @@ class EmitterOutput:
     ]


+def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
+    gm = exported_program.graph_module
+    output_node = None
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            output_node = node
+    assert output_node is not None
+
+    mutated_outputs: List[Optional[str]] = [
+        out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
+        for out_spec in exported_program.graph_signature.output_specs
+    ]
+    outputs = pytree.tree_flatten(output_node.args)[0]
+
+    user_output_nodes = []
+    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
+        if mutated_node_name is None:
+            user_output_nodes.append(return_node)
+            continue
+
+    with gm.graph.inserting_before(output_node):
+        # Only return user outputs
+        new_output = gm.graph.output(tuple(user_output_nodes))
+        new_output.meta = output_node.meta.copy()
+        output_node.replace_all_uses_with(new_output)
+        gm.graph.erase_node(output_node)
+
+    return gm
+
+
 def emit_program(
     methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
     emit_stacktrace: bool = False,
@@ -163,13 +194,6 @@ def emit_program(

     # emit each entry point in order according to name.
     for name, exported_program in sorted(methods.items()):
-        if (
-            exported_program.graph_signature.buffers_to_mutate
-        ):  # see if we are mutating any state
-            raise ExportError(
-                ExportErrorType.INVALID_INPUT_TYPE,
-                "Buffers cannot be modified in executorch.",
-            )
         # create empty state
         emitter_state = _EmitterState(
             values=[],
@@ -180,7 +204,11 @@
             emit_stacktrace=emit_stacktrace,
         )

-        emitter = _TopLevelEmitter(name, exported_program, program_state, emitter_state)
+        gm = _remove_non_user_outputs(exported_program)
+
+        emitter = _TopLevelEmitter(
+            name, exported_program, gm, program_state, emitter_state
+        )

         emitter.run()
         plans.append(emitter.plan())
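
To make the new _remove_non_user_outputs filtering concrete, here is a tiny standalone illustration with plain lists standing in for the flattened graph outputs and the graph signature's output specs; it is not part of the commit:

from typing import List, Optional

# Stand-ins: the exported graph returns (state_mutation, y) and the output
# specs mark the first entry as a BUFFER_MUTATION, so only y is kept.
outputs = ["state_mutation_node", "y_node"]
mutated_outputs: List[Optional[str]] = ["state", None]  # buffer fqn or None

user_output_nodes = [
    node for node, mutated in zip(outputs, mutated_outputs) if mutated is None
]
assert user_output_nodes == ["y_node"]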

exir/emit/_emitter.py

Lines changed: 73 additions & 32 deletions
@@ -32,8 +32,9 @@
 import hashlib
 import operator
 import typing
+import warnings
 from dataclasses import dataclass, field
-from typing import Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

 import executorch.exir.memory as memory
 import executorch.extension.pytree as ex_pytree
@@ -1266,15 +1267,17 @@ def __init__(
         self,
         name: str,
         exported_program: ExportedProgram,
+        graph_module: torch.fx.GraphModule,
         program_state: _ProgramState,
         emitter_state: _EmitterState,
     ) -> None:
-        super().__init__(exported_program.graph_module, emitter_state, program_state)
+        super().__init__(graph_module, emitter_state, program_state)
         self.name = name
         self.exported_program = exported_program

         self.inputs: List[int] = []
         self.outputs: List[int] = []
+        self.given_mutable_buffer_warning = False

         def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
             if spec is None:
@@ -1293,6 +1296,42 @@ def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
             inp_container_str, out_container_str
         )

+    def _find_fqn_for_placeholder(
+        self, target: _Target, spec: Any  # pyre-ignore[2]
+    ) -> Tuple[Optional[str], bool]:
+        # Find the fully qualified name
+        fqn = None
+        is_mutable_buffer = False
+        if target in self.exported_program.graph_signature.inputs_to_parameters:
+            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+        elif target in self.exported_program.graph_signature.inputs_to_buffers:
+            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
+
+            # if the buffer is mutated then record that
+            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
+                is_mutable_buffer = True
+                if not self.given_mutable_buffer_warning:
+                    warnings.warn(
+                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
+                        "buffers that are mutated in the graph have a meaningless initial state, "
+                        "only the shape and dtype will be serialized.",
+                        UserWarning,
+                        stacklevel=1,
+                    )
+                    self.given_mutable_buffer_warning = True
+
+        elif (
+            target
+            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        ):
+            fqn = (
+                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
+                    target
+                ]
+            )
+        return fqn, is_mutable_buffer
+
     def placeholder(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _AbstractValue:
@@ -1302,40 +1341,27 @@
         https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
         """
         spec = self.node.meta["spec"]
-        const_tensor = False
-        if isinstance(target, str) and (
-            target in self.exported_program.graph_signature.inputs_to_parameters
-            or target in self.exported_program.graph_signature.inputs_to_buffers
-            or target
-            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-        ):
-            if (
-                target
-                in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-            ):
-                fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
-                    target
-                ]
-            elif target in self.exported_program.graph_signature.inputs_to_buffers:
-                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
-            else:
-                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+        is_user_input = True
+
+        if isinstance(target, str) and isinstance(spec, TensorSpec):
+
+            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
+
+            # From the fqn find the corresponding tensor
+            real_tensor = None
             if fqn in self.exported_program.state_dict:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.state_dict[fqn], const=True
-                )
-                const_tensor = True
+                real_tensor = self.exported_program.state_dict[fqn]
+                is_user_input = False
+
             elif fqn in self.exported_program.constants:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.constants[fqn], const=True
-                )
-                const_tensor = True
-            else:
+                real_tensor = self.exported_program.constants[fqn]
+                is_user_input = False
+            elif fqn is not None:
                 buffers = self.exported_program.named_buffers()
                 buf = next((x[1] for x in buffers if x[0] == fqn), None)
                 if buf is not None:
-                    spec = TensorSpec.from_tensor(buf, const=True)
-                    const_tensor = True
+                    real_tensor = buf
+                    is_user_input = False
                 else:
                     raise InternalError(
                         self._emit_node_specific_error(
@@ -1344,13 +1370,28 @@
                        )
                    )

+            # assign the storage of the placeholder spec to the storage of the real tensor if there is one
+            if real_tensor is not None:
+                # for non-contigous tensors, convert to a contiguous one
+                real_tensor = real_tensor.contiguous()
+                # Weights cannot be views during emission or serialization
+                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
+                    real_tensor = real_tensor.clone()
+
+                spec.storage = real_tensor.untyped_storage()
+
+            # User inputs and mutable buffers are not constants, other buffers or parameters are.
+            spec.const = not (is_user_input or is_mutable_buffer)
+
         evalue = (
             self._tensor_spec_to_evalue(spec)
             if isinstance(spec, TensorSpec)
             else self._constant_to_evalue(spec, None)
         )
         value = self._emit_evalue(evalue)
-        if not const_tensor:
+
+        # Only user inputs should remain as inputs.
+        if is_user_input:
             self.inputs.append(value.id)

         return value
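
A side note on the view check above (real_tensor.nbytes != real_tensor.untyped_storage().nbytes()): the following standalone snippet, not part of the commit, shows why a view must be cloned before its untyped storage can stand in for the weight's payload:

import torch

weights = torch.arange(8.0)   # 8 float32 values backed by a 32-byte storage
view = weights[:4]            # nbytes == 16, but it shares the full 32-byte storage
assert view.nbytes != view.untyped_storage().nbytes()

owned = view.clone()          # the clone owns exactly the bytes it needs
assert owned.nbytes == owned.untyped_storage().nbytes()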

exir/emit/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,6 @@ python_unittest(
         "//executorch/exir/passes:constant_prop_pass",
         "//executorch/exir/tests:lib",
         "//executorch/exir/tests:models",
-        "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/extension/pybindings:aten_lib",
     ],
 )

exir/emit/test/test_emit.py

Lines changed: 40 additions & 0 deletions
@@ -19,6 +19,7 @@
 from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager, to_edge
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -42,6 +43,7 @@
 )
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.models import Mul
+from executorch.extension.pybindings.aten_lib import _load_for_executorch_from_buffer
 from functorch.experimental import control_flow
 from torch import nn

@@ -1393,3 +1395,41 @@ def forward(self, x):
         self.assertEqual(len(exec_plan.inputs), 1)
         self.assertEqual(len(program.constant_buffer), 2)
         self.assertEqual(len(program.constant_buffer[1].storage), 24)
+
+    def test_mutable_buffers(self) -> None:
+        def count_copies(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target == torch.ops.aten.copy_
+                    or node.target == exir_ops.edge.aten.copy_.default
+                )
+                for node in gm.graph.nodes
+            )
+
+        class MutableStateModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x):
+                y = x + self.state
+                self.state.add_(1)
+                return y
+
+        model = to_edge(
+            export(
+                MutableStateModule(),
+                (torch.zeros(1),),
+            )
+        )
+        model = model.to_executorch()
+        model.dump_executorch_program(True)
+        self.assertTrue(
+            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
+            .values[0]
+            .val.allocation_info
+            is not None
+        )
+        executorch_module = _load_for_executorch_from_buffer(model.buffer)
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
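
The new test is wired up through the Buck target updated in exir/emit/test/TARGETS above. As a hedged alternative for a source checkout with the aten_lib pybindings built, something like the following should select just this test (command assumed, not taken from the commit):

python -m pytest exir/emit/test/test_emit.py -k test_mutable_buffers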
