Skip to content

Commit cd2af09

Browse files
cccclaifacebook-github-bot
authored andcommitted
exclude mutated buffer (#2876)
Summary: Pull Request resolved: #2876 Differential Revision: D55812844
1 parent 86b326a commit cd2af09

File tree

2 files changed

+106
-15
lines changed

2 files changed

+106
-15
lines changed

exir/backend/test/test_partitioner.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
2727
ExecutorBackend,
2828
)
29-
from executorch.exir.backend.utils import get_delegates
29+
from executorch.exir.backend.utils import get_delegates, tag_constant_data
3030

3131
from executorch.exir.dialects._ops import ops as exir_ops
3232

@@ -523,3 +523,85 @@ def partition(
523523
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
524524
str(error.exception),
525525
)
526+
527+
def test_not_delegate_mutable_buffers(self) -> None:
528+
"""
529+
A test case to check the mutated buffer is not delegated. We'll need to add a test case
530+
to consider when the delegate can consume the mutable buffer.
531+
"""
532+
533+
class MutableStateModule(torch.nn.Module):
534+
def __init__(self):
535+
super().__init__()
536+
self.register_buffer("my_state", torch.zeros(1))
537+
538+
def forward(self, x):
539+
y = x + self.my_state
540+
self.my_state.add_(1)
541+
return y
542+
543+
edge = exir.to_edge(
544+
torch.export.export(
545+
MutableStateModule(),
546+
(torch.zeros(1),),
547+
)
548+
)
549+
self.assertGreater(
550+
len(edge.exported_program().graph_signature.buffers_to_mutate),
551+
0,
552+
"The test case should at leaset one mutable buffer",
553+
)
554+
555+
class PartitionerTagData(Partitioner):
556+
def __init__(self):
557+
super().__init__()
558+
self.delegation_spec = DelegationSpec(
559+
ExecutorBackend.__name__,
560+
[CompileSpec(key, value) for key, value in self.spec.items()],
561+
)
562+
563+
def partition(
564+
self, edge_exported_program: ExportedProgram
565+
) -> PartitionResult:
566+
partition_tags = {}
567+
for node in edge_exported_program.graph.nodes:
568+
if node.op == "call_function" and node.target in [
569+
exir_ops.edge.aten.add.Tensor
570+
]:
571+
delegation_tag = "tag0"
572+
node.meta["delegation_tag"] = delegation_tag
573+
partition_tags[delegation_tag] = self.delegation_spec
574+
tag_constant_data(edge_exported_program)
575+
return PartitionResult(
576+
tagged_exported_program=edge_exported_program,
577+
partition_tags=partition_tags,
578+
)
579+
580+
# Check the edge program inital buffers_to_mutate
581+
mutate_op = "aten_add_tensor_1"
582+
self.assertEqual(
583+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
584+
"my_state",
585+
)
586+
edge = edge.to_backend(PartitionerTagData())
587+
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
588+
self.assertNotIn(
589+
mutate_op,
590+
edge.exported_program().graph_signature.buffers_to_mutate,
591+
)
592+
593+
mutate_op = "getitem_1"
594+
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
595+
self.assertEqual(
596+
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
597+
"my_state",
598+
)
599+
# Check the copy_ node is inserted
600+
edge = edge.to_executorch()
601+
copy_node = [
602+
node
603+
for node in edge.exported_program().graph.nodes
604+
if node.op == "call_function"
605+
and node.target == torch.ops.aten.copy_.default
606+
]
607+
self.assertEqual(len(copy_node), 1)

exir/backend/utils.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,20 +515,29 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
515515
or is_buffer(edge_program, node)
516516
or is_lifted_tensor_constant(edge_program, node)
517517
):
518-
user_tags = set()
519-
for user in node.users:
520-
user_tag = user.meta.get("delegation_tag", None)
521-
if user_tag is not None:
522-
user_tags.add(user_tag)
523-
if len(user_tags) > 1:
524-
logging.info(
525-
f"The data node is used across multiple partitions, including {user_tags}. "
526-
"If the data is too large and it's not preferred to copy, please tag the "
527-
"constant node like node.['no_copy'] = True and they won't be copied."
528-
)
529-
# tag the data node with the same tag as the last user
530-
if len(user_tags) > 0:
531-
node.meta["delegation_tag"] = user_tags.pop()
518+
is_mutated = False
519+
for node_user in node.users:
520+
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
521+
logging.info(
522+
"The buffer node is a mutated buffer node, which is not constant."
523+
)
524+
is_mutated = True
525+
break
526+
if not is_mutated:
527+
user_tags = set()
528+
for user in node.users:
529+
user_tag = user.meta.get("delegation_tag", None)
530+
if user_tag is not None:
531+
user_tags.add(user_tag)
532+
if len(user_tags) > 1:
533+
logging.info(
534+
f"The data node is used across multiple partitions, including {user_tags}. "
535+
"If the data is too large and it's not preferred to copy, please tag the "
536+
"constant node like node.['no_copy'] = True and they won't be copied."
537+
)
538+
# tag the data node with the same tag as the last user
539+
if len(user_tags) > 0:
540+
node.meta["delegation_tag"] = user_tags.pop()
532541

533542

534543
# TODO - style: use templated types

0 commit comments

Comments
 (0)