emit programs with mutable buffers #2233

Closed
wants to merge 1 commit
46 changes: 37 additions & 9 deletions exir/emit/_emit_program.py
@@ -32,7 +32,8 @@
)
from executorch.exir.tensor import layout_enum, scalar_type_enum
from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
from torch.export.exported_program import ExportedProgram
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.utils import _pytree as pytree


def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
@@ -122,6 +123,36 @@ class EmitterOutput:
]


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
gm = exported_program.graph_module
output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
assert output_node is not None

mutated_outputs: List[Optional[str]] = [
out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
for out_spec in exported_program.graph_signature.output_specs
]
outputs = pytree.tree_flatten(output_node.args)[0]

user_output_nodes = []
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
if mutated_node_name is None:
user_output_nodes.append(return_node)
continue

with gm.graph.inserting_before(output_node):
# Only return user outputs
new_output = gm.graph.output(tuple(user_output_nodes))
new_output.meta = output_node.meta.copy()
output_node.replace_all_uses_with(new_output)
gm.graph.erase_node(output_node)

return gm


def emit_program(
methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
emit_stacktrace: bool = False,
Expand Down Expand Up @@ -163,13 +194,6 @@ def emit_program(

# emit each entry point in order according to name.
for name, exported_program in sorted(methods.items()):
if (
exported_program.graph_signature.buffers_to_mutate
): # see if we are mutating any state
raise ExportError(
ExportErrorType.INVALID_INPUT_TYPE,
"Buffers cannot be modified in executorch.",
)
# create empty state
emitter_state = _EmitterState(
values=[],
@@ -180,7 +204,11 @@
emit_stacktrace=emit_stacktrace,
)

emitter = _TopLevelEmitter(name, exported_program, program_state, emitter_state)
gm = _remove_non_user_outputs(exported_program)

emitter = _TopLevelEmitter(
name, exported_program, gm, program_state, emitter_state
)

emitter.run()
plans.append(emitter.plan())
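
Below is a minimal sketch (not part of this commit) of why _remove_non_user_outputs is needed, assuming a recent torch.export; the Counter module is hypothetical and only mirrors the test added later in this PR. Exporting a module that mutates a buffer lifts the mutation into an extra graph output tagged OutputKind.BUFFER_MUTATION, which the pass above strips so that only user outputs reach the emitted execution plan:

import torch
from torch.export import export

class Counter(torch.nn.Module):  # hypothetical toy module for illustration
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1))

    def forward(self, x):
        self.state.add_(1)     # buffer mutation, lifted to a BUFFER_MUTATION output
        return x + self.state  # the only user-visible output

ep = export(Counter(), (torch.zeros(1),))
for out_spec in ep.graph_signature.output_specs:
    # Expected, roughly: (OutputKind.BUFFER_MUTATION, "state") followed by
    # (OutputKind.USER_OUTPUT, None); the first entry is what gets filtered out.
    print(out_spec.kind, out_spec.target)
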
105 changes: 73 additions & 32 deletions exir/emit/_emitter.py
@@ -32,8 +32,9 @@
import hashlib
import operator
import typing
import warnings
from dataclasses import dataclass, field
from typing import Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

import executorch.exir.memory as memory
import executorch.extension.pytree as ex_pytree
@@ -1266,15 +1267,17 @@ def __init__(
self,
name: str,
exported_program: ExportedProgram,
graph_module: torch.fx.GraphModule,
program_state: _ProgramState,
emitter_state: _EmitterState,
) -> None:
super().__init__(exported_program.graph_module, emitter_state, program_state)
super().__init__(graph_module, emitter_state, program_state)
self.name = name
self.exported_program = exported_program

self.inputs: List[int] = []
self.outputs: List[int] = []
self.given_mutable_buffer_warning = False

def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
if spec is None:
@@ -1293,6 +1296,42 @@ def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
inp_container_str, out_container_str
)

def _find_fqn_for_placeholder(
self, target: _Target, spec: Any # pyre-ignore[2]
) -> Tuple[Optional[str], bool]:
# Find the fully qualified name
fqn = None
is_mutable_buffer = False
if target in self.exported_program.graph_signature.inputs_to_parameters:
fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

elif target in self.exported_program.graph_signature.inputs_to_buffers:
fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

# if the buffer is mutated then record that
if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
is_mutable_buffer = True
if not self.given_mutable_buffer_warning:
warnings.warn(
"Mutation on a buffer in the model is detected. ExecuTorch assumes "
"buffers that are mutated in the graph have a meaningless initial state, "
"only the shape and dtype will be serialized.",
UserWarning,
stacklevel=1,
)
self.given_mutable_buffer_warning = True

elif (
target
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
fqn = (
self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
target
]
)
return fqn, is_mutable_buffer

def placeholder(
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> _AbstractValue:
@@ -1302,40 +1341,27 @@ def placeholder(
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
"""
spec = self.node.meta["spec"]
const_tensor = False
if isinstance(target, str) and (
target in self.exported_program.graph_signature.inputs_to_parameters
or target in self.exported_program.graph_signature.inputs_to_buffers
or target
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
if (
target
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
target
]
elif target in self.exported_program.graph_signature.inputs_to_buffers:
fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
else:
fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):

fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

# From the fqn find the corresponding tensor
real_tensor = None
if fqn in self.exported_program.state_dict:
spec = TensorSpec.from_tensor(
self.exported_program.state_dict[fqn], const=True
)
const_tensor = True
real_tensor = self.exported_program.state_dict[fqn]
is_user_input = False

elif fqn in self.exported_program.constants:
spec = TensorSpec.from_tensor(
self.exported_program.constants[fqn], const=True
)
const_tensor = True
else:
real_tensor = self.exported_program.constants[fqn]
is_user_input = False
elif fqn is not None:
buffers = self.exported_program.named_buffers()
buf = next((x[1] for x in buffers if x[0] == fqn), None)
if buf is not None:
spec = TensorSpec.from_tensor(buf, const=True)
const_tensor = True
real_tensor = buf
is_user_input = False
else:
raise InternalError(
self._emit_node_specific_error(
@@ -1344,13 +1370,28 @@
)
)

# assign the storage of the placeholder spec to the storage of the real tensor if there is one
if real_tensor is not None:
# for non-contiguous tensors, convert to a contiguous one
real_tensor = real_tensor.contiguous()
# Weights cannot be views during emission or serialization
if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
real_tensor = real_tensor.clone()

spec.storage = real_tensor.untyped_storage()

# User inputs and mutable buffers are not constants, other buffers or parameters are.
spec.const = not (is_user_input or is_mutable_buffer)

evalue = (
self._tensor_spec_to_evalue(spec)
if isinstance(spec, TensorSpec)
else self._constant_to_evalue(spec, None)
)
value = self._emit_evalue(evalue)
if not const_tensor:

# Only user inputs should remain as inputs.
if is_user_input:
self.inputs.append(value.id)

return value
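
A hedged illustration of the signature lookups that _find_fqn_for_placeholder performs, reusing the hypothetical Counter module from the earlier sketch (placeholder names vary across torch versions, so the printed keys are only indicative):

ep = export(Counter(), (torch.zeros(1),))
sig = ep.graph_signature
print(sig.inputs_to_buffers)   # e.g. {"b_state": "state"}; placeholder name -> buffer fqn
print(sig.buffers_to_mutate)   # e.g. {"add": "state"}; mutated output node -> buffer fqn
# Because "state" appears in buffers_to_mutate.values(), its placeholder is a
# mutable buffer: spec.const ends up False and only shape/dtype are serialized,
# while parameters and unmutated buffers keep spec.const = True.
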
2 changes: 1 addition & 1 deletion exir/emit/test/TARGETS
@@ -21,6 +21,6 @@ python_unittest(
"//executorch/exir/passes:constant_prop_pass",
"//executorch/exir/tests:lib",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pybindings:aten_lib",
],
)
40 changes: 40 additions & 0 deletions exir/emit/test/test_emit.py
@@ -19,6 +19,7 @@
from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager, to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.emit import emit_program # noqa
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -42,6 +43,7 @@
)
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.models import Mul
from executorch.extension.pybindings.aten_lib import _load_for_executorch_from_buffer
from functorch.experimental import control_flow
from torch import nn

@@ -1393,3 +1395,41 @@ def forward(self, x):
self.assertEqual(len(exec_plan.inputs), 1)
self.assertEqual(len(program.constant_buffer), 2)
self.assertEqual(len(program.constant_buffer[1].storage), 24)

def test_mutable_buffers(self) -> None:
def count_copies(gm: torch.fx.GraphModule) -> int:
return sum(
(
node.target == torch.ops.aten.copy_
or node.target == exir_ops.edge.aten.copy_.default
)
for node in gm.graph.nodes
)

class MutableStateModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("state", torch.zeros(1))

def forward(self, x):
y = x + self.state
self.state.add_(1)
return y

model = to_edge(
export(
MutableStateModule(),
(torch.zeros(1),),
)
)
model = model.to_executorch()
model.dump_executorch_program(True)
self.assertTrue(
model.executorch_program.execution_plan[0] # pyre-ignore[16]
.values[0]
.val.allocation_info
is not None
)
executorch_module = _load_for_executorch_from_buffer(model.buffer)
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
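
Note that the count_copies helper defined at the top of this test is not exercised in the diff as shown. A plausible follow-up assertion (hypothetical, assuming the ExecutorchProgramManager.exported_program() accessor) would verify that the buffer mutation lowers to exactly one copy_ back into the buffer:

# Hypothetical extra check, not part of this commit:
gm = model.exported_program().graph_module
self.assertEqual(count_copies(gm), 1)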