Skip to content

Commit d66ce3f

Browse files
angelayi authored and facebook-github-bot committed
Allow delegate to consume buffer mutations
Summary: Fixing #4209 Edge Program: ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"): # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x) aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x); b_b = None # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor); x = None return (aten_add_tensor, aten_add_tensor_1) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)]) ``` Partitioned / lowered Exported Program (buffer mutation gets removed): ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 3]"): # No stacktrace found for following nodes lowered_module_0 = self.lowered_module_0 executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x); lowered_module_0 = x = None # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b 
getitem_1: "f32[3, 3]" = executorch_call_delegate[0]; executorch_call_delegate = None return (getitem_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)]) ``` Delegate (consumes the buffer mutation): ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"): # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x) aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x); b_b = None # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor); x = None return (aten_add_tensor, aten_add_tensor_1) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)]) ``` Differential Revision: D60838243
1 parent c7aff77 commit d66ce3f

File tree

3 files changed

+194
-6
lines changed

3 files changed

+194
-6
lines changed

exir/backend/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ python_unittest(
290290
"//executorch/exir/backend/test/demos/rpc:executor_backend_register",
291291
],
292292
deps = [
293+
":op_partitioner_demo",
293294
"//caffe2:torch",
294295
"//executorch/exir:lib",
295296
"//executorch/exir/backend:backend_details",

exir/backend/test/test_partitioner.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
2727
ExecutorBackend,
2828
)
29+
from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo
2930
from executorch.exir.backend.utils import get_delegates, tag_constant_data
3031

3132
from executorch.exir.dialects._ops import ops as exir_ops
@@ -619,3 +620,90 @@ def partition(
619620
and node.target == torch.ops.aten.copy_.default
620621
]
621622
self.assertEqual(len(copy_node), 1)
623+
624+
def test_buffer_mutation1(self):
625+
class TestModule(torch.nn.Module):
626+
def __init__(self):
627+
super().__init__()
628+
self.register_buffer("b", torch.ones(3, 3))
629+
630+
def forward(self, x):
631+
self.b.add_(x)
632+
return x + self.b
633+
634+
model_inputs = (torch.ones(3, 3),)
635+
orig_res = TestModule()(*model_inputs)
636+
edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
637+
lowered = edge_program.to_backend(AddAttributePartitionerDemo())
638+
639+
self.assertTrue(
640+
torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
641+
)
642+
643+
self.assertEqual(
644+
len(lowered.exported_program().graph_signature.buffers_to_mutate),
645+
0,
646+
)
647+
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
648+
self.assertEqual(len(lowered_module_nodes), 1)
649+
lowered_module_node = lowered_module_nodes[0]
650+
651+
# get call delegate node
652+
call_delegate_node = list(lowered_module_node.users.keys())[0]
653+
self.assertEqual(len(call_delegate_node.args), 2)
654+
655+
lower_module = getattr(
656+
lowered.exported_program().graph_module, lowered_module_node.name
657+
)
658+
delegated_ep = lower_module.original_module
659+
660+
self.assertEqual(len(delegated_ep.state_dict), 1)
661+
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
662+
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
663+
664+
def test_buffer_mutation2(self):
665+
SHAPE = (2, 3)
666+
667+
class Model(torch.nn.Module):
668+
def __init__(self):
669+
super().__init__()
670+
self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))
671+
672+
def forward(self, x):
673+
add = self.state_1.add_(x)
674+
return add
675+
676+
model = Model()
677+
model.eval()
678+
679+
example_inputs = (torch.randn(SHAPE),)
680+
exir_program_aten = torch.export.export(model, example_inputs)
681+
edge_program_manager = exir.to_edge(exir_program_aten)
682+
lowered = edge_program_manager.to_backend(AddAttributePartitionerDemo())
683+
684+
self.assertTrue(
685+
torch.allclose(
686+
lowered.exported_program().module()(*example_inputs), example_inputs[0]
687+
)
688+
)
689+
690+
self.assertEqual(
691+
len(lowered.exported_program().graph_signature.buffers_to_mutate),
692+
0,
693+
)
694+
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
695+
self.assertEqual(len(lowered_module_nodes), 1)
696+
lowered_module_node = lowered_module_nodes[0]
697+
698+
# get call delegate node
699+
call_delegate_node = list(lowered_module_node.users.keys())[0]
700+
self.assertEqual(len(call_delegate_node.args), 2)
701+
702+
lower_module = getattr(
703+
lowered.exported_program().graph_module, lowered_module_node.name
704+
)
705+
delegated_ep = lower_module.original_module
706+
707+
self.assertEqual(len(delegated_ep.state_dict), 1)
708+
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
709+
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

exir/lowered_backend_module.py

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,16 @@ def _get_new_signature( # noqa: C901
488488
else {}
489489
)
490490

491+
toplevel_output_node_to_sig: Dict[str, OutputSpec] = (
492+
{
493+
output_spec.arg.name: output_spec
494+
for output_spec in old_signature.output_specs
495+
}
496+
if not is_submodule
497+
else {}
498+
)
499+
491500
for node in gm.graph.nodes:
492-
is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
493501
if node.op == "placeholder":
494502

495503
if node.name not in input_node_to_sig:
@@ -507,7 +515,7 @@ def _get_new_signature( # noqa: C901
507515
if not isinstance(orig_input_spec.arg, TensorArgument):
508516
input_specs.append(orig_input_spec)
509517

510-
elif is_tagged:
518+
elif node.meta.get("delegation_tag", None) == tag:
511519
input_specs.append(orig_input_spec)
512520

513521
if orig_input_spec.kind == InputKind.USER_INPUT:
@@ -551,11 +559,55 @@ def _get_new_signature( # noqa: C901
551559
)
552560

553561
if node.op == "output":
554-
output_nodes = pytree.tree_leaves((node.args, node.kwargs))
562+
buffer_mutation_idxs: Dict[int, OutputSpec] = {}
563+
for user in call_module_node.users.keys():
564+
if user.name in toplevel_output_node_to_sig:
565+
assert (
566+
user.op == "call_function" and user.target == operator.getitem
567+
), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
568+
getitem_idx = user.args[1]
569+
assert isinstance(
570+
getitem_idx, int
571+
), f"Invalid getitem type: {type(getitem_idx)}"
572+
buffer_mutation_idxs[getitem_idx] = toplevel_output_node_to_sig[
573+
user.name
574+
]
575+
576+
for i, output_node in enumerate(node.args[0]):
577+
if i in buffer_mutation_idxs:
578+
assert isinstance(output_node, torch.fx.Node)
579+
orig_output_spec = buffer_mutation_idxs[i]
580+
581+
if (
582+
orig_output_spec.kind == OutputKind.BUFFER_MUTATION
583+
and orig_output_spec.target in new_state_dict
584+
):
585+
# If the delegate wants to consume the buffer, then
586+
# the delegate should also consume the buffer
587+
# mutation (output spec would be a BUFFER_MUTATION).
588+
# Otherwise the delegate will just return the result
589+
# of the mutation as a USER_OUTPUT.
590+
output_specs.append(
591+
OutputSpec(
592+
kind=OutputKind.BUFFER_MUTATION,
593+
arg=TensorArgument(name=output_node.name),
594+
target=orig_output_spec.target,
595+
)
596+
)
597+
output_specs_to_delete[orig_output_spec.arg.name] = (
598+
orig_output_spec
599+
)
555600

556-
for output_node in output_nodes:
601+
else:
602+
output_specs.append(
603+
OutputSpec(
604+
kind=OutputKind.USER_OUTPUT,
605+
arg=TensorArgument(name=output_node.name),
606+
target=None,
607+
)
608+
)
557609

558-
if not isinstance(output_node, torch.fx.Node):
610+
elif not isinstance(output_node, torch.fx.Node):
559611
output_specs.append(
560612
OutputSpec(
561613
kind=OutputKind.USER_OUTPUT,
@@ -774,7 +826,7 @@ def get_lowered_backend_modules(
774826
return lowered_programs
775827

776828

777-
def _unsafe_adjust_original_program(
829+
def _unsafe_adjust_original_program( # noqa: C901
778830
original_program: ExportedProgram,
779831
call_delegate_node: torch.fx.Node,
780832
input_specs_to_delete: Dict[str, InputSpec],
@@ -830,3 +882,50 @@ def _unsafe_adjust_original_program(
830882
del original_program._constants[input_spec.target]
831883
else:
832884
raise RuntimeError(f"Invalid input spec {input_spec} received")
885+
886+
# Delete buffer mutations from the output which were consumed by the delegate
887+
toplevel_output_node = None
888+
for node in reversed(original_program.graph.nodes):
889+
if node.op == "output":
890+
toplevel_output_node = node
891+
break
892+
893+
assert toplevel_output_node is not None
894+
assert (
895+
len(toplevel_output_node.args) == 1
896+
), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
897+
898+
new_output_args = [
899+
arg
900+
for arg in toplevel_output_node.args[0]
901+
if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
902+
]
903+
toplevel_output_node.args = (tuple(new_output_args),)
904+
905+
# Delete the buffer mutation getitem nodes
906+
getitem_idxs: List[int] = []
907+
user_nodes = list(call_delegate_node.users.keys())
908+
for user in user_nodes:
909+
if user.name in output_specs_to_delete:
910+
assert (
911+
user.op == "call_function" and user.target == operator.getitem
912+
), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
913+
user_idx = user.args[1]
914+
assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
915+
getitem_idxs.append(user_idx)
916+
original_program.graph.erase_node(user)
917+
918+
getitem_idxs.sort(reverse=True)
919+
920+
# Adjust all the getitem indices after the deleted getitems
921+
user_nodes = list(call_delegate_node.users.keys())
922+
for user in user_nodes:
923+
assert user.op == "call_function" and user.target == operator.getitem
924+
user_idx = user.args[1]
925+
assert isinstance(user_idx, int)
926+
for i, idx in enumerate(getitem_idxs):
927+
if user_idx > idx:
928+
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
929+
break
930+
931+
original_program._validate()

0 commit comments

Comments
 (0)