Skip to content

Commit d3adb47

Browse files
cccclaifacebook-github-bot
authored andcommitted
exclude mutated buffer (#2876)
Summary: Pull Request resolved: #2876 Differential Revision: D55812844
1 parent 86b326a commit d3adb47

File tree

2 files changed

+101
-15
lines changed

2 files changed

+101
-15
lines changed

exir/backend/test/test_partitioner.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
2727
ExecutorBackend,
2828
)
29-
from executorch.exir.backend.utils import get_delegates
29+
from executorch.exir.backend.utils import get_delegates, tag_constant_data
3030

3131
from executorch.exir.dialects._ops import ops as exir_ops
3232

@@ -523,3 +523,80 @@ def partition(
523523
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
524524
str(error.exception),
525525
)
526+
527+
def test_mutable_buffers(self) -> None:
528+
class MutableStateModule(torch.nn.Module):
529+
def __init__(self):
530+
super().__init__()
531+
self.register_buffer("my_state", torch.zeros(1))
532+
533+
def forward(self, x):
534+
y = x + self.my_state
535+
self.my_state.add_(1)
536+
return y
537+
538+
edge = exir.to_edge(
539+
torch.export.export(
540+
MutableStateModule(),
541+
(torch.zeros(1),),
542+
)
543+
)
544+
self.assertGreater(
545+
len(edge.exported_program().graph_signature.buffers_to_mutate),
546+
0,
547+
"The test case should at leaset one mutable buffer",
548+
)
549+
550+
class PartitionerTagData(Partitioner):
551+
def __init__(self):
552+
super().__init__()
553+
self.delegation_spec = DelegationSpec(
554+
ExecutorBackend.__name__,
555+
[CompileSpec(key, value) for key, value in self.spec.items()],
556+
)
557+
558+
def partition(
559+
self, edge_exported_program: ExportedProgram
560+
) -> PartitionResult:
561+
partition_tags = {}
562+
for node in edge_exported_program.graph.nodes:
563+
if node.op == "call_function" and node.target in [
564+
exir_ops.edge.aten.add.Tensor
565+
]:
566+
delegation_tag = "tag0"
567+
node.meta["delegation_tag"] = delegation_tag
568+
partition_tags[delegation_tag] = self.delegation_spec
569+
tag_constant_data(edge_exported_program)
570+
return PartitionResult(
571+
tagged_exported_program=edge_exported_program,
572+
partition_tags=partition_tags,
573+
)
574+
575+
# Check the edge program inital buffers_to_mutate
576+
mutate_op = "aten_add_tensor_1"
577+
self.assertEqual(
578+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
579+
"my_state",
580+
)
581+
edge = edge.to_backend(PartitionerTagData())
582+
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
583+
self.assertNotIn(
584+
mutate_op,
585+
edge.exported_program().graph_signature.buffers_to_mutate,
586+
)
587+
588+
mutate_op = "getitem_1"
589+
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
590+
self.assertEqual(
591+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
592+
"my_state",
593+
)
594+
# Check the copy_ node is inserted
595+
edge = edge.to_executorch()
596+
copy_node = [
597+
node
598+
for node in edge.exported_program().graph.nodes
599+
if node.op == "call_function"
600+
and node.target == torch.ops.aten.copy_.default
601+
]
602+
self.assertEqual(len(copy_node), 1)

exir/backend/utils.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,20 +515,29 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
515515
or is_buffer(edge_program, node)
516516
or is_lifted_tensor_constant(edge_program, node)
517517
):
518-
user_tags = set()
519-
for user in node.users:
520-
user_tag = user.meta.get("delegation_tag", None)
521-
if user_tag is not None:
522-
user_tags.add(user_tag)
523-
if len(user_tags) > 1:
524-
logging.info(
525-
f"The data node is used across multiple partitions, including {user_tags}. "
526-
"If the data is too large and it's not preferred to copy, please tag the "
527-
"constant node like node.['no_copy'] = True and they won't be copied."
528-
)
529-
# tag the data node with the same tag as the last user
530-
if len(user_tags) > 0:
531-
node.meta["delegation_tag"] = user_tags.pop()
518+
is_mutated = False
519+
for node_user in node.users:
520+
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
521+
logging.info(
522+
"The buffer node is a mutated buffer node, which is not constant."
523+
)
524+
is_mutated = True
525+
break
526+
if not is_mutated:
527+
user_tags = set()
528+
for user in node.users:
529+
user_tag = user.meta.get("delegation_tag", None)
530+
if user_tag is not None:
531+
user_tags.add(user_tag)
532+
if len(user_tags) > 1:
533+
logging.info(
534+
f"The data node is used across multiple partitions, including {user_tags}. "
535+
"If the data is too large and it's not preferred to copy, please tag the "
536+
"constant node like node.['no_copy'] = True and they won't be copied."
537+
)
538+
# tag the data node with the same tag as the last user
539+
if len(user_tags) > 0:
540+
node.meta["delegation_tag"] = user_tags.pop()
532541

533542

534543
# TODO - style: use templated types

0 commit comments

Comments
 (0)