
Commit c857cc3

Fix write back to buffers for buffer mutation (#7810)
* Fixed missing writeback copy operation in insert_write_back_for_buffers_pass for the case of copying data directly from one input to another. Also converted the list comprehension to a for loop for readability.
* Add unit test.
* Fix linter errors.
1 parent edbdbfb commit c857cc3
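
For context, here is a minimal sketch of the pattern the fix targets, based on the unit test added in this commit (the BufferFromInput class name and the import paths are inferred from the file layout shown here, so treat them as illustrative): a buffer whose new value comes directly from a user input, so the mutated output aliases a different placeholder than the buffer it must be written back into.

import torch
from torch.export import export

from executorch.exir import to_edge
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)


# Mirrors the unit test below: one buffer mutated in place, one buffer
# overwritten directly from the user input x.
class BufferFromInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(1))
        self.register_buffer("direct_copy_from_input", torch.zeros(1))

    def forward(self, x):
        y = x + self.state
        self.state.add_(1)                    # in-place update of a buffer
        self.direct_copy_from_input.copy_(x)  # buffer filled straight from the input
        return y


edge = to_edge(export(BufferFromInput(), (torch.zeros(1),), strict=True))
# Before this commit, the pass skipped the second buffer because its new value
# was itself an input placeholder; after the fix it emits a copy_ for both.
gm, _ = insert_write_back_for_buffers_pass(edge.exported_program())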

2 files changed: +37 -28 lines changed


exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 16 additions & 12 deletions
@@ -100,19 +100,23 @@ def insert_write_back_for_buffers_pass(
         input_name_to_node[lifted_node] = input_node

     # Grab the mutable buffer nodes in the outputs,
-    mutated_outputs: List[Optional[str]] = [
-        (
-            out_spec.target
-            if out_spec.kind
+    mutated_outputs: List[Optional[str]] = []
+    for out_spec in ep.graph_signature.output_specs:
+        # if the output arg is the input value then all operations on it are in-place
+        # so there's no need to add a copy_ node
+        if (
+            out_spec.kind
             in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
-            and out_spec.arg.name
-            not in {
-                val.name for val in input_name_to_node.values()
-            }  # if the output arg is the input value then all operations on it are in-place so theres no need to add a copy_ node
-            else None
-        )
-        for out_spec in ep.graph_signature.output_specs
-    ]
+            and
+            # explicitly check if target exists (it should always be there)
+            out_spec.target in input_name_to_node
+            and
+            # if the arg and target are not the same, we add a copy_ node.
+            out_spec.arg.name != input_name_to_node[out_spec.target].name
+        ):
+            mutated_outputs.append(out_spec.target)
+        else:
+            mutated_outputs.append(None)

     # insert the copy ops and update the outputs
     buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node)
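
To make the new selection logic concrete, here is a self-contained sketch of the same decision over stand-in types (FakeSpec, FakeNode, and the kind strings are illustrative, not the real torch.export OutputSpec, fx.Node, or OutputKind). The old comprehension returned None whenever the output arg shared a name with any input placeholder, which wrongly skipped a buffer whose new value is a different input; the rewritten check only skips the truly in-place case where the arg is the mutated placeholder itself.

from dataclasses import dataclass
from typing import Dict, List, Optional


# Illustrative stand-ins; the real pass works on torch.export's OutputSpec
# and fx.Node objects, and on the OutputKind enum rather than strings.
@dataclass
class FakeNode:
    name: str


@dataclass
class FakeSpec:
    kind: str      # e.g. "BUFFER_MUTATION" or "USER_OUTPUT"
    target: str    # name of the lifted buffer / input being mutated
    arg_name: str  # name of the node that produces this output


def select_writebacks(
    specs: List[FakeSpec], input_name_to_node: Dict[str, FakeNode]
) -> List[Optional[str]]:
    # Same decision as the fixed loop: record a write-back target only when the
    # output arg is a different node than the placeholder it mutates.
    mutated: List[Optional[str]] = []
    for spec in specs:
        if (
            spec.kind in ("BUFFER_MUTATION", "USER_INPUT_MUTATION")
            and spec.target in input_name_to_node
            and spec.arg_name != input_name_to_node[spec.target].name
        ):
            mutated.append(spec.target)
        else:
            mutated.append(None)
    return mutated


inputs = {name: FakeNode(name) for name in ("b_state", "b_direct", "b_inplace", "x")}
specs = [
    FakeSpec("BUFFER_MUTATION", "b_state", "aten_add_tensor_1"),  # new value from a compute node
    FakeSpec("BUFFER_MUTATION", "b_direct", "x"),                 # new value is another input (the fixed case)
    FakeSpec("BUFFER_MUTATION", "b_inplace", "b_inplace"),        # arg is the placeholder itself: in-place
    FakeSpec("USER_OUTPUT", "", "aten_add_tensor"),               # not a mutation
]
assert select_writebacks(specs, inputs) == ["b_state", "b_direct", None, None]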

exir/tests/test_passes.py

Lines changed: 21 additions & 16 deletions
@@ -1291,36 +1291,41 @@ class MutableStateModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.register_buffer("state", torch.zeros(1))
+                self.register_buffer("direct_copy_from_input", torch.zeros(1))

             def forward(self, x):
                 y = x + self.state
                 self.state.add_(1)
+                self.direct_copy_from_input.copy_(x)
                 return y

         model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
         self.assertEqual(count_copies(model.exported_program().graph_module), 0)
         # Before
         # graph():
-        # %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
-        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
-        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
-        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
-        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
-        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
-        # return (aten_add_tensor_1, aten_add_tensor)
+        # %b_state : [num_users=2] = placeholder[target=b_state]
+        # %b_direct_copy_from_input : [num_users=0] = placeholder[target=b_direct_copy_from_input]
+        # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
+        # %x : [num_users=2] = placeholder[target=x]
+        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
+        # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
+        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
+        # return (aten_add_tensor_1, x, aten_add_tensor)
         gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

         # After
         # graph():
-        # %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
-        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
-        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
-        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
-        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
-        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
-        # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
-        # return (copy__default, aten_add_tensor)
-        self.assertEqual(count_copies(gm), 1)
+        # %b_state : [num_users=3] = placeholder[target=b_state]
+        # %b_direct_copy_from_input : [num_users=1] = placeholder[target=b_direct_copy_from_input]
+        # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
+        # %x : [num_users=2] = placeholder[target=x]
+        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
+        # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
+        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
+        # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_state, %aten_add_tensor_1), kwargs = {})
+        # %copy__default_1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_direct_copy_from_input, %x), kwargs = {})
+        # return (copy__default, copy__default_1, aten_add_tensor)
+        self.assertEqual(count_copies(gm), 2)

     def test_remove_quantized_op_noop_pass(self) -> None:
         class TestAddSliceNoop(torch.nn.Module):
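
The assertions above rely on a count_copies helper defined elsewhere in test_passes.py and not included in this diff. A plausible stand-in, assuming it simply counts aten.copy_ call_function nodes in the graph module (the real helper may differ), would be:

import torch
from torch.fx import GraphModule


def count_copies(gm: GraphModule) -> int:
    # Assumed behavior: count call_function nodes targeting aten.copy_.default,
    # which is the op inserted by _insert_copy in the write-back pass.
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default
    )

With the fixed pass, this counts both the copy_ back into b_state and the new copy_ from x into b_direct_copy_from_input, which is why the expected value changes from 1 to 2.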
