Skip to content

Commit 8f7d9d5

Browse files
helunwencser authored and facebook-github-bot committed
Allow mutating input tensor (#4850)
Summary:
Pull Request resolved: #4850

To support a dynamic KV cache, we need to pass the KV cache in as an input tensor and update it inside the model. This PR allows mutating input tensors.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: JacobSzwejbka

Differential Revision: D61683366

Pulled By: helunwencser

fbshipit-source-id: b480073d16ddcc624d12c23918a78dfca966e0dd
1 parent aebc2e3 commit 8f7d9d5

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

exir/emit/test/test_emit.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1649,3 +1649,23 @@ def forward(self, x):
16491649
self.assertEqual(
16501650
pte_data.execution_plan, model.executorch_program.execution_plan
16511651
)
1652+
1653+
def test_mutate_input_tensor(self) -> None:
1654+
class MutateInputTensorModule(torch.nn.Module):
1655+
def __init__(self):
1656+
super().__init__()
1657+
1658+
def forward(self, x):
1659+
x.add_(1)
1660+
1661+
model = to_edge(
1662+
export(MutateInputTensorModule(), (torch.zeros(1),))
1663+
).to_executorch(
1664+
config=ExecutorchBackendConfig(
1665+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
1666+
)
1667+
)
1668+
executorch_model = _load_for_executorch_from_buffer(model.buffer)
1669+
input = torch.zeros(1)
1670+
executorch_model(input)
1671+
self.assertEqual(input, torch.ones(1))

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
OutputKind,
1616
OutputSpec,
1717
)
18+
from torch.export.graph_signature import TensorArgument
1819
from torch.utils import _pytree as pytree
1920

2021

@@ -73,20 +74,21 @@ def insert_write_back_for_buffers_pass(
7374
ep: ExportedProgram,
7475
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
7576
gm: torch.fx.GraphModule = ep.graph_module
76-
lifted_inputs: List[Optional[str]] = [
77-
(
78-
in_spec.target
79-
if in_spec.kind
80-
in (
81-
InputKind.BUFFER,
82-
InputKind.CONSTANT_TENSOR,
83-
InputKind.PARAMETER,
84-
InputKind.CUSTOM_OBJ,
85-
)
86-
else None
87-
)
88-
for in_spec in ep.graph_signature.input_specs
89-
]
77+
lifted_inputs: List[Optional[str]] = []
78+
for in_spec in ep.graph_signature.input_specs:
79+
if in_spec.kind in (
80+
InputKind.BUFFER,
81+
InputKind.CONSTANT_TENSOR,
82+
InputKind.PARAMETER,
83+
InputKind.CUSTOM_OBJ,
84+
):
85+
lifted_inputs.append(in_spec.target)
86+
elif in_spec.kind is InputKind.USER_INPUT and isinstance(
87+
in_spec.arg, TensorArgument
88+
):
89+
lifted_inputs.append(in_spec.arg.name)
90+
else:
91+
lifted_inputs.append(None)
9092

9193
input_name_to_node: Dict[str, torch.fx.Node] = {}
9294

@@ -101,7 +103,8 @@ def insert_write_back_for_buffers_pass(
101103
mutated_outputs: List[Optional[str]] = [
102104
(
103105
out_spec.target
104-
if out_spec.kind in (OutputKind.BUFFER_MUTATION,)
106+
if out_spec.kind
107+
in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
105108
and out_spec.arg.name
106109
not in {
107110
val.name for val in input_name_to_node.values()
@@ -121,7 +124,10 @@ def insert_write_back_for_buffers_pass(
121124
new_output_specs: List[OutputSpec] = []
122125
i = 0
123126
for output_spec in ep.graph_signature.output_specs:
124-
if output_spec.kind == OutputKind.BUFFER_MUTATION:
127+
if output_spec.kind in (
128+
OutputKind.BUFFER_MUTATION,
129+
OutputKind.USER_INPUT_MUTATION,
130+
):
125131
output_spec.arg.name = buffer_output_nodes[i].name
126132
i += 1
127133
new_output_specs.append(output_spec)

0 commit comments

Comments
 (0)