Skip to content

Commit 6f9ea64

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix constant tagging util function (#2816)
Summary: Fix the constant nodes tagging util function ``` python3 -m examples.models.llama2.export_llama --coreml --use_kv_cache ``` Reviewed By: mcr229 Differential Revision: D55663281
1 parent 399482c commit 6f9ea64

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

exir/backend/utils.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
509509
underlying data will be owned by multiple delegates.
510510
"""
511511
for node in edge_program.graph.nodes:
512-
# go through const/param/buffer nodes
512+
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
513513
is_attr = (
514514
node.op == "placeholder"
515515
and (
@@ -518,17 +518,20 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
518518
or is_lifted_tensor_constant(edge_program, node)
519519
)
520520
) or (node.op == "get_attr")
521-
# if all users of const/param/buffer nodes are partitioned then partition
522521
if is_attr:
523522
user_tags = set()
524523
for user in node.users:
525-
user_tags.add(user.meta.get("delegation_tag", None))
526-
assert len(user_tags) <= 1, (
527-
"Const/Param/Buffer users have multiple tags because one constant data can't "
528-
"be owned by multiple backends. Consider duplicating the constant data so that "
529-
"each user is unique"
530-
)
531-
if len(user_tags) == 1:
524+
user_tag = user.meta.get("delegation_tag", None)
525+
if user_tag is not None:
526+
user_tags.add(user_tag)
527+
if len(user_tags) > 1:
528+
logging.info(
529+
f"The data node is used across multiple partitions, including {user_tags}. "
530+
"If the data is too large and it's not prefered to copied, please tag the "
531+
"constant node like node.['no_copy'] = True and they won't be copied."
532+
)
533+
# tag the data node with the same tag as the last user
534+
if len(user_tags) > 0:
532535
node.meta["delegation_tag"] = user_tags.pop()
533536

534537

0 commit comments

Comments
 (0)