Commit 1fed8e9

JacobSzwejbka authored and facebook-github-bot committed
emit programs with mutable buffers
Summary: Meaningful changes to the emitter logic here. Previously we ignored the tensor spec passed in: if the emitter decided a placeholder was a constant, it created a fresh spec from that constant's actual value, which drops metadata from the input spec. Now we instead look up the storage of the concrete tensor and hook it up to the existing spec. Also added logic to separate out the behavior for mutable buffers specifically.

While working on this I also discovered a bug: memory planning reserves space for parameters and constant buffers if it is told to allocate space for inputs, which is really bad.

One big assumption this diff makes is that the buffer has no meaningful initial state. In the short term we should probably emit a warning about this during emission; long term we will handle initial state properly.

Reviewed By: tarun292

Differential Revision: D53713544
1 parent ec54b4c commit 1fed8e9
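
The core emission change can be sketched in plain PyTorch. The Spec class and attach_storage helper below are hypothetical stand-ins for exir's TensorSpec handling (only the storage and const fields touched by the diff are modeled); the point is that the incoming spec is kept, and only the concrete tensor's storage is hooked up:

import torch

class Spec:
    # Hypothetical stand-in for exir's TensorSpec: only the two fields the
    # diff touches (storage, const) are modeled here.
    storage = None
    const = False

def attach_storage(
    spec: Spec, real_tensor: torch.Tensor, is_user_input: bool, is_mutable_buffer: bool
) -> Spec:
    # Mirror the diff: keep the incoming spec (and its metadata) and only
    # hook up the concrete tensor's storage, instead of rebuilding the spec.
    real_tensor = real_tensor.contiguous()
    # A view's nbytes is smaller than its backing storage's; clone so the
    # serialized storage exactly backs the tensor.
    if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
        real_tensor = real_tensor.clone()
    spec.storage = real_tensor.untyped_storage()
    # User inputs and mutable buffers are not constants.
    spec.const = not (is_user_input or is_mutable_buffer)
    return spec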

File tree

4 files changed: +133 −44 lines

exir/emit/_emit_program.py

Lines changed: 36 additions & 9 deletions
@@ -32,7 +32,8 @@
 )
 from executorch.exir.tensor import layout_enum, scalar_type_enum
 from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.utils import _pytree as pytree


 def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
@@ -122,6 +123,35 @@ class EmitterOutput:
     ]


+def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
+    gm = exported_program.graph_module
+    output_node = None
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            output_node = node
+    assert output_node is not None
+
+    mutated_outputs: List[Optional[str]] = [
+        out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
+        for out_spec in exported_program.graph_signature.output_specs
+    ]
+    outputs = pytree.tree_flatten(output_node.args)[0]
+
+    user_output_nodes = []
+    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
+        if mutated_node_name is None:
+            user_output_nodes.append(return_node)
+            continue
+
+    with gm.graph.inserting_before(output_node):
+        # Only return user outputs
+        new_output = gm.graph.output(tuple(user_output_nodes))
+        output_node.replace_all_uses_with(new_output)
+        gm.graph.erase_node(output_node)
+
+    return gm
+
+
 def emit_program(
     methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
     emit_stacktrace: bool = False,
@@ -163,13 +193,6 @@ def emit_program(

     # emit each entry point in order according to name.
     for name, exported_program in sorted(methods.items()):
-        if (
-            exported_program.graph_signature.buffers_to_mutate
-        ):  # see if we are mutating any state
-            raise ExportError(
-                ExportErrorType.INVALID_INPUT_TYPE,
-                "Buffers cannot be modified in executorch.",
-            )
         # create empty state
         emitter_state = _EmitterState(
             values=[],
@@ -180,7 +203,11 @@ def emit_program(
             emit_stacktrace=emit_stacktrace,
         )

-        emitter = _TopLevelEmitter(name, exported_program, program_state, emitter_state)
+        gm = _remove_non_user_outputs(exported_program)
+
+        emitter = _TopLevelEmitter(
+            name, exported_program, gm, program_state, emitter_state
+        )

         emitter.run()
         plans.append(emitter.plan())
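
For context on what _remove_non_user_outputs filters, here is a minimal sketch (assuming a recent torch.export; Counter is an illustrative module, not from this commit) that prints the output specs the helper walks; the BUFFER_MUTATION entries are exactly the ones it drops from the output node:

import torch
from torch.export import export

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1))

    def forward(self, x):
        self.state.add_(1)  # buffer mutation becomes an extra graph output
        return x + self.state

ep = export(Counter(), (torch.zeros(1),))
for out_spec in ep.graph_signature.output_specs:
    # Expect a BUFFER_MUTATION spec (target == "state") alongside the
    # USER_OUTPUT spec for the value returned to the caller.
    print(out_spec.kind, out_spec.target)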

exir/emit/_emitter.py

Lines changed: 58 additions & 34 deletions
@@ -1243,15 +1243,15 @@ def run_node(self, n: torch.fx.Node) -> None:
         https://pytorch.org/docs/stable/fx.html#torch.fx.Node
         """
         self.node = n
-        try:
-            ret = super().run_node(n)
-        except Exception as e:
-            if isinstance(e, (InternalError, ExportError)):
-                raise e
-            else:
-                raise InternalError(
-                    self._emit_node_specific_error(self.node, str(e))
-                ) from e
+        # try:
+        ret = super().run_node(n)
+        # except Exception as e:
+        #     if isinstance(e, (InternalError, ExportError)):
+        #         raise e
+        #     else:
+        #         raise InternalError(
+        #             self._emit_node_specific_error(self.node, str(e))
+        #         ) from e
         return ret


@@ -1266,10 +1266,11 @@ def __init__(
         self,
         name: str,
         exported_program: ExportedProgram,
+        graph_module: torch.fx.GraphModule,
         program_state: _ProgramState,
         emitter_state: _EmitterState,
     ) -> None:
-        super().__init__(exported_program.graph_module, emitter_state, program_state)
+        super().__init__(graph_module, emitter_state, program_state)
         self.name = name
         self.exported_program = exported_program

@@ -1302,40 +1303,48 @@ def placeholder(
         https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
         """
         spec = self.node.meta["spec"]
-        const_tensor = False
-        if isinstance(target, str) and (
-            target in self.exported_program.graph_signature.inputs_to_parameters
-            or target in self.exported_program.graph_signature.inputs_to_buffers
-            or target
-            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-        ):
-            if (
+        is_user_input = True
+
+        if isinstance(target, str) and isinstance(spec, TensorSpec):
+            # Find the fully qualified name
+            fqn = None
+            is_mutable_buffer = False
+            if target in self.exported_program.graph_signature.inputs_to_parameters:
+                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+            elif target in self.exported_program.graph_signature.inputs_to_buffers:
+                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
+
+                # if the buffer is mutated then record that
+                if (
+                    fqn
+                    in self.exported_program.graph_signature.buffers_to_mutate.values()
+                ):
+                    is_mutable_buffer = True
+
+            elif (
                 target
                 in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                     target
                 ]
-            elif target in self.exported_program.graph_signature.inputs_to_buffers:
-                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
-            else:
-                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+            # From the fqn find the corresponding tensor
+            real_tensor = None
             if fqn in self.exported_program.state_dict:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.state_dict[fqn], const=True
-                )
-                const_tensor = True
+                real_tensor = self.exported_program.state_dict[fqn]
+                is_user_input = False
+
             elif fqn in self.exported_program.constants:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.constants[fqn], const=True
-                )
-                const_tensor = True
-            else:
+                real_tensor = self.exported_program.constants[fqn]
+                is_user_input = False
+            elif fqn is not None:
                 buffers = self.exported_program.named_buffers()
                 buf = next((x[1] for x in buffers if x[0] == fqn), None)
                 if buf is not None:
-                    spec = TensorSpec.from_tensor(buf, const=True)
-                    const_tensor = True
+                    real_tensor = buf
+                    is_user_input = False
                 else:
                     raise InternalError(
                         self._emit_node_specific_error(
@@ -1344,13 +1353,28 @@ def placeholder(
                         )
                     )

+        # assign the storage of the placeholder spec to the storage of the real tensor if there is one
+        if real_tensor is not None:
+            # for non-contiguous tensors, convert to a contiguous one
+            real_tensor = real_tensor.contiguous()
+            # Weights cannot be views during emission or serialization
+            if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
+                real_tensor = real_tensor.clone()
+
+            spec.storage = real_tensor.untyped_storage()
+
+        # User inputs and mutable buffers are not constants; other buffers and parameters are.
+        spec.const = not (is_user_input or is_mutable_buffer)
+
         evalue = (
             self._tensor_spec_to_evalue(spec)
             if isinstance(spec, TensorSpec)
             else self._constant_to_evalue(spec, None)
         )
         value = self._emit_evalue(evalue)
-        if not const_tensor:
+
+        # Only user inputs should remain as inputs.
+        if is_user_input:
             self.inputs.append(value.id)

         return value
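
The view check in the placeholder hunk relies on a PyTorch invariant: a view's nbytes is smaller than its backing storage's. A standalone illustration (plain PyTorch; the variable names are illustrative):

import torch

weight = torch.arange(8, dtype=torch.float32)
view = weight[:2]  # shares weight's 32-byte storage but only covers 8 bytes
assert view.nbytes == 8
assert view.untyped_storage().nbytes() == 32
assert view.nbytes != view.untyped_storage().nbytes()

owned = view.clone().contiguous()  # after cloning, the storage exactly backs the tensor
assert owned.nbytes == owned.untyped_storage().nbytes()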

exir/emit/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -21,6 +21,6 @@ python_unittest(
         "//executorch/exir/passes:constant_prop_pass",
         "//executorch/exir/tests:lib",
         "//executorch/exir/tests:models",
-        "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/extension/pybindings:aten_lib",
     ],
 )

exir/emit/test/test_emit.py

Lines changed: 38 additions & 0 deletions
@@ -19,6 +19,7 @@
 from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager, to_edge
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -42,6 +43,7 @@
 )
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.models import Mul
+from executorch.extension.pybindings.aten_lib import _load_for_executorch_from_buffer
 from functorch.experimental import control_flow
 from torch import nn

@@ -1393,3 +1395,39 @@ def forward(self, x):
         self.assertEqual(len(exec_plan.inputs), 1)
         self.assertEqual(len(program.constant_buffer), 2)
         self.assertEqual(len(program.constant_buffer[1].storage), 24)
+
+    def test_mutable_buffers(self) -> None:
+        def count_copies(gm: torch.fx.GraphModule) -> int:
+            return sum(
+                (
+                    node.target == torch.ops.aten.copy_
+                    or node.target == exir_ops.edge.aten.copy_.default
+                )
+                for node in gm.graph.nodes
+            )
+
+        class MutableStateModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x):
+                y = x + self.state
+                self.state.add_(1)
+                return y
+
+        model = to_edge(
+            export(
+                MutableStateModule(),
+                (torch.zeros(1),),
+            )
+        )
+        model = model.to_executorch()
+        model.dump_executorch_program(True)
+        self.assertTrue(
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info  # pyre-ignore[16]
+            is not None
+        )
+        executorch_module = _load_for_executorch_from_buffer(model.buffer)
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
+        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
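
Note that the count_copies helper is defined but never called in this hunk. A hypothetical usage sketch (exported_program() on the program manager is assumed from the exir API, not shown in this diff):

# Hypothetical: count the copy_ nodes that the buffer mutation lowered to.
gm = model.exported_program().graph_module
print(count_copies(gm))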
