
Commit 843f6d0

angelayi authored and facebook-github-bot committed
Allow delegate to consume buffer mutations (#4830)
Summary:

Pull Request resolved: #4830

Fixes #4209.

Edge Program:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None),
    ],
)
```

Partitioned / lowered Exported Program (the buffer mutation is removed from the toplevel graph):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None),
    ],
)
```

Delegate (consumes the buffer mutation):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None),
    ],
)
```

Differential Revision: D60838243
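For context, a minimal sketch (not part of this commit) of how the programs above can be reproduced. It assumes the usual `from executorch import exir` import and the `AddAttributePartitionerDemo` exercised by the new tests; the module mirrors test_partitioner.py:631-632.

```
# Sketch only: rebuild the module from the summary, export it, and print the
# edge / lowered programs. Assumes executorch and its demo partitioner are importable.
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(3, 3))

    def forward(self, x):
        self.b.add_(x)  # in-place buffer mutation
        return x + self.b


example_inputs = (torch.ones(3, 3),)
edge = exir.to_edge(torch.export.export(TestModule(), example_inputs))
print(edge.exported_program())  # the "Edge Program" above

lowered = edge.to_backend(AddAttributePartitionerDemo())
print(lowered.exported_program())  # toplevel program; buffer mutation removed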
1 parent e636ef6 commit 843f6d0

6 files changed (+289, -8 lines)


backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion

```
@@ -229,7 +229,7 @@ def lower_module_and_test_output(
     compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]

     if use_partitioner:
-        logging.info(f"Edge IR graph:\n{edge_program.exported_program().graph}")
+        logging.info(f"Edge IR graph:\n{edge_program.exported_program()}")
         delegated_program = edge_program
         delegated_program = edge_program.to_backend(
             MPSPartitioner(compile_specs=compile_specs)
```

exir/backend/test/TARGETS

Lines changed: 3 additions & 0 deletions

```
@@ -88,6 +88,8 @@ python_library(
        "//executorch/exir/backend:compile_spec_schema",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
+       "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
+       "//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
        "//executorch/exir/dialects:lib",
    ],
 )
@@ -290,6 +292,7 @@ python_unittest(
        "//executorch/exir/backend/test/demos/rpc:executor_backend_register",
    ],
    deps = [
+       ":op_partitioner_demo",
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir/backend:backend_details",
```

exir/backend/test/op_partitioner_demo.py

Lines changed: 50 additions & 0 deletions

```
@@ -21,6 +21,9 @@
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
 )
+from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
+    ExecutorBackend,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
@@ -29,6 +32,11 @@
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


+class AllOperatorSupport(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op == "call_function"
+
+
 class AddOperatorSupport(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
@@ -126,6 +134,48 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
     )


+@final
+class AllNodesPartitionerDemo(Partitioner):
+    """
+    Partitions all nodes
+    """
+
+    def __init__(self) -> None:
+        self.op_support = AllOperatorSupport()
+        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])
+
+    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
+        partition_tags = {}
+        partition_list = generate_pattern_op_partitions(
+            edge_exported_program.graph_module, op_support=self.op_support
+        )
+        for partition in partition_list:
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                partition_tags[delegation_tag] = self.delegation_spec
+
+                # Tag the add nodes
+                node.meta["delegation_tag"] = delegation_tag
+
+                for arg_node in node.args:
+                    if not isinstance(arg_node, torch.fx.Node):
+                        continue
+
+                    is_get_attr = arg_node.op == "get_attr"
+                    is_param_buffer = arg_node.op == "placeholder" and (
+                        is_param(edge_exported_program, arg_node)
+                        or is_buffer(edge_exported_program, arg_node)
+                        or is_lifted_tensor_constant(edge_exported_program, arg_node)
+                    )
+                    if is_get_attr or is_param_buffer:
+                        arg_node.meta["delegation_tag"] = delegation_tag
+                    # Add to the list of partitioned nodes.
+
+        return PartitionResult(
+            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
+        )
+
+
 ops_not_to_decompose = [
     torch.ops.aten.linear.default,
     torch.ops.aten.scaled_dot_product_attention.default,
```
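For orientation, a rough usage sketch (mirroring the test_buffer_mutation_llama_repro test added below; not part of this file) of how AllNodesPartitionerDemo pushes an entire edge graph, including a mutated buffer, into the ExecutorBackend delegate:

```
# Sketch only; names and flow are taken from the new unit tests.
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AllNodesPartitionerDemo
from executorch.exir.backend.utils import get_delegates


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(2, 3))

    def forward(self, q, k_val, input_pos):
        k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
        return k.mm(q.transpose(0, 1))


inputs = (torch.rand(1, 3), torch.rand(1, 3), torch.tensor([1, 1]))
edge = exir.to_edge(torch.export.export(Model(), inputs))
lowered = edge.to_backend(AllNodesPartitionerDemo())

# The toplevel signature no longer carries the buffer mutation ...
assert len(lowered.exported_program().graph_signature.buffers_to_mutate) == 0

# ... because the delegate's ExportedProgram now owns the buffer and its mutation.
delegate_node = get_delegates(lowered.exported_program().graph)[0]
delegate_ep = getattr(
    lowered.exported_program().graph_module, delegate_node.name
).original_module
assert len(delegate_ep.graph_signature.buffers_to_mutate) == 1
```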

exir/backend/test/test_partitioner.py

Lines changed: 112 additions & 0 deletions

```
@@ -26,6 +26,10 @@
 from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
     ExecutorBackend,
 )
+from executorch.exir.backend.test.op_partitioner_demo import (
+    AddAttributePartitionerDemo,
+    AllNodesPartitionerDemo,
+)
 from executorch.exir.backend.utils import get_delegates, tag_constant_data

 from executorch.exir.dialects._ops import ops as exir_ops
@@ -619,3 +623,111 @@ def partition(
             and node.target == torch.ops.aten.copy_.default
         ]
         self.assertEqual(len(copy_node), 1)
+
+    def test_buffer_mutation1(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("b", torch.ones(3, 3))
+
+            def forward(self, x):
+                self.b.add_(x)
+                return x + self.b
+
+        model_inputs = (torch.ones(3, 3),)
+        orig_res = TestModule()(*model_inputs)
+        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
+        lowered = edge_program.to_backend(AddAttributePartitionerDemo())
+
+        self.assertTrue(
+            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
+        )
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # get call delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 2)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_llama_repro(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, q, k_val, input_pos):
+                q_T = q.transpose(0, 1)
+                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
+                attn = k.mm(q_T)
+                return attn
+
+        q = torch.rand(1, 3)
+        k = torch.rand(1, 3)
+        example_inputs = (q, k, torch.tensor([1, 1]))
+
+        model = Model()
+        model.eval()
+
+        exir_program_aten = torch.export.export(model, example_inputs)
+        exir_program_aten.module()(*example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # get call delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 4)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_unsupported(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, x):
+                add = self.state_1.add_(x)
+                return add
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(SHAPE),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        with self.assertRaises(AssertionError):
+            edge_program_manager.to_backend(AddAttributePartitionerDemo())
```

exir/lowered_backend_module.py

Lines changed: 122 additions & 6 deletions

```
@@ -8,6 +8,7 @@

 import copy
 import operator
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -488,8 +489,12 @@ def _get_new_signature(  # noqa: C901
         else {}
     )

+    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
+    if not is_submodule:
+        for output_spec in old_signature.output_specs:
+            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)
+
     for node in gm.graph.nodes:
-        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
         if node.op == "placeholder":

             if node.name not in input_node_to_sig:
@@ -507,7 +512,7 @@ def _get_new_signature(  # noqa: C901
             if not isinstance(orig_input_spec.arg, TensorArgument):
                 input_specs.append(orig_input_spec)

-            elif is_tagged:
+            elif node.meta.get("delegation_tag", None) == tag:
                 input_specs.append(orig_input_spec)

                 if orig_input_spec.kind == InputKind.USER_INPUT:
```
```
@@ -551,11 +556,72 @@ def _get_new_signature(  # noqa: C901
             )

         if node.op == "output":
-            output_nodes = pytree.tree_leaves((node.args, node.kwargs))
+            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
+            for user in call_module_node.users.keys():
+                if user.name in toplevel_output_node_to_sig:
+                    assert (
+                        user.op == "call_function" and user.target == operator.getitem
+                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
+                    getitem_idx = user.args[1]
+                    assert isinstance(
+                        getitem_idx, int
+                    ), f"Invalid getitem type: {type(getitem_idx)}"
+                    buffer_mutation_idxs[getitem_idx].extend(
+                        toplevel_output_node_to_sig[user.name]
+                    )

-            for output_node in output_nodes:
+            for i, output_node in enumerate(node.args[0]):
+                if i in buffer_mutation_idxs:
+                    assert isinstance(output_node, torch.fx.Node)
+                    orig_output_specs = buffer_mutation_idxs[i]
+
+                    if any(
+                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                        and orig_output_spec.target in new_state_dict
+                        for orig_output_spec in orig_output_specs
+                    ):
+                        # If the delegate wants to consume the buffer, then the
+                        # delegate should also consume the buffer mutation
+                        # (output spec would be a BUFFER_MUTATION). Otherwise
+                        # the delegate will just return the result of the
+                        # mutation as a USER_OUTPUT.
+
+                        orig_output_spec = [
+                            orig_output_spec
+                            for orig_output_spec in orig_output_specs
+                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                            and orig_output_spec.target in new_state_dict
+                        ][0]
+
+                        assert len(orig_output_specs) == 1, (
+                            f"Constant {orig_output_spec.target} was tagged to be "
+                            "consumed by the buffer, and was found to also contain "
+                            "a buffer mutation. However this buffer mutation node "
+                            "was found to also be used as other types of outputs "
+                            "which is currently not supported. Please file an "
+                            "issue on Github. \n\n"
+                            f"The toplevel program: {original_program}\n"
+                        )
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.BUFFER_MUTATION,
+                                arg=TensorArgument(name=output_node.name),
+                                target=orig_output_spec.target,
+                            )
+                        )
+                        output_specs_to_delete[orig_output_spec.arg.name] = (
+                            orig_output_spec
+                        )
+                    else:
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.USER_OUTPUT,
+                                arg=TensorArgument(name=output_node.name),
+                                target=None,
+                            )
+                        )

-                if not isinstance(output_node, torch.fx.Node):
+                elif not isinstance(output_node, torch.fx.Node):
                     output_specs.append(
                         OutputSpec(
                             kind=OutputKind.USER_OUTPUT,
```
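The branch added above makes one decision per delegate output: if the toplevel program recorded that output as a BUFFER_MUTATION and the mutated buffer now lives in the delegate's state dict, the delegate's signature keeps it as a BUFFER_MUTATION (and the toplevel spec is queued for deletion); otherwise the value is surfaced as a plain USER_OUTPUT. A simplified, self-contained sketch of that decision, using stand-in types rather than the real exir OutputSpec/OutputKind classes:

```
# Illustration only: stand-in types, not the exir API.
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional


class Kind(Enum):
    USER_OUTPUT = 1
    BUFFER_MUTATION = 3


@dataclass
class Spec:
    kind: Kind
    name: str              # name of the delegate output node
    target: Optional[str]  # buffer name for BUFFER_MUTATION, else None


def classify_delegate_output(
    output_name: str,
    orig_specs: List[Spec],                  # toplevel specs mapped to this getitem index
    delegate_state_dict: Dict[str, object],  # buffers the delegate consumed
) -> Spec:
    mutations = [
        s
        for s in orig_specs
        if s.kind is Kind.BUFFER_MUTATION and s.target in delegate_state_dict
    ]
    if mutations:
        # The delegate consumed the buffer, so it also consumes the mutation.
        assert len(orig_specs) == 1, "mutated value reused as another output kind"
        return Spec(Kind.BUFFER_MUTATION, output_name, mutations[0].target)
    # Otherwise the delegate just returns the mutated value to the caller.
    return Spec(Kind.USER_OUTPUT, output_name, None)


# Example: the delegate owns buffer 'b', so its first output stays a buffer mutation.
spec = classify_delegate_output(
    "aten_add_tensor",
    [Spec(Kind.BUFFER_MUTATION, "aten_add_tensor", "b")],
    {"b": object()},
)
assert spec.kind is Kind.BUFFER_MUTATION and spec.target == "b"
```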
```
@@ -630,6 +696,9 @@ def create_exported_program_from_submodule(
     in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
     out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

+    print(submodule.graph)
+    print(subgraph_signature)
+
     return (
         ExportedProgram(
             root=submodule,
@@ -774,7 +843,7 @@ def get_lowered_backend_modules(
     return lowered_programs


-def _unsafe_adjust_original_program(
+def _unsafe_adjust_original_program(  # noqa: C901
     original_program: ExportedProgram,
     call_delegate_node: torch.fx.Node,
     input_specs_to_delete: Dict[str, InputSpec],
```
```
@@ -830,3 +899,50 @@ def _unsafe_adjust_original_program(  # noqa: C901
             del original_program._constants[input_spec.target]
         else:
             raise RuntimeError(f"Invalid input spec {input_spec} received")
+
+    # Delete buffer mutations from the output which were consumed by the delegate
+    toplevel_output_node = None
+    for node in reversed(original_program.graph.nodes):
+        if node.op == "output":
+            toplevel_output_node = node
+            break
+
+    assert toplevel_output_node is not None
+    assert (
+        len(toplevel_output_node.args) == 1
+    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
+
+    new_output_args = [
+        arg
+        for arg in toplevel_output_node.args[0]
+        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
+    ]
+    toplevel_output_node.args = (tuple(new_output_args),)
+
+    # Delete the buffer mutation getitem nodes
+    getitem_idxs: List[int] = []
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        if user.name in output_specs_to_delete:
+            assert (
+                user.op == "call_function" and user.target == operator.getitem
+            ), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
+            user_idx = user.args[1]
+            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
+            getitem_idxs.append(user_idx)
+            original_program.graph.erase_node(user)
+
+    getitem_idxs.sort(reverse=True)
+
+    # Adjust all the getitem indices after the deleted getitems
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        assert user.op == "call_function" and user.target == operator.getitem
+        user_idx = user.args[1]
+        assert isinstance(user_idx, int)
+        for i, idx in enumerate(getitem_idxs):
+            if user_idx > idx:
+                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
+                break
+
+    original_program._validate()
```
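The re-indexing loop at the end of this hunk is plain index compaction: once the getitem nodes for consumed buffer mutations are erased, every surviving getitem index is shifted down by the number of deleted indices below it, which is what `user_idx - (len(getitem_idxs) - i)` computes against the descending-sorted list. A small standalone sketch of just that arithmetic, with plain integers instead of FX nodes:

```
# Illustration only: the index compaction applied to surviving getitem users.
from typing import List


def compact_index(user_idx: int, deleted_idxs: List[int]) -> int:
    """Shift user_idx down by the number of deleted indices smaller than it."""
    deleted = sorted(deleted_idxs, reverse=True)
    for i, idx in enumerate(deleted):
        if user_idx > idx:
            return user_idx - (len(deleted) - i)
    return user_idx


# With delegate output 0 (the buffer mutation) deleted, the surviving user output
# moves from index 1 to index 0 -- matching the lowered program in the summary.
assert compact_index(1, [0]) == 0
assert compact_index(3, [0, 2]) == 1
```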
