Skip to content

Commit 3477af8

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix constant tagging util function (#2816)
Summary: Fix the constant nodes tagging util function ``` python3 -m examples.models.llama2.export_llama --coreml --use_kv_cache ``` Reviewed By: mcr229 Differential Revision: D55663281
1 parent 399482c commit 3477af8

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

exir/backend/utils.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -509,26 +509,25 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
509509
underlying data will be owned by multiple delegates.
510510
"""
511511
for node in edge_program.graph.nodes:
512-
# go through const/param/buffer nodes
513-
is_attr = (
514-
node.op == "placeholder"
515-
and (
516-
is_param(edge_program, node)
517-
or is_buffer(edge_program, node)
518-
or is_lifted_tensor_constant(edge_program, node)
519-
)
520-
) or (node.op == "get_attr")
521-
# if all users of const/param/buffer nodes are partitioned then partition
522-
if is_attr:
512+
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
513+
if node.op == "placeholder" and (
514+
is_param(edge_program, node)
515+
or is_buffer(edge_program, node)
516+
or is_lifted_tensor_constant(edge_program, node)
517+
):
523518
user_tags = set()
524519
for user in node.users:
525-
user_tags.add(user.meta.get("delegation_tag", None))
526-
assert len(user_tags) <= 1, (
527-
"Const/Param/Buffer users have multiple tags because one constant data can't "
528-
"be owned by multiple backends. Consider duplicating the constant data so that "
529-
"each user is unique"
530-
)
531-
if len(user_tags) == 1:
520+
user_tag = user.meta.get("delegation_tag", None)
521+
if user_tag is not None:
522+
user_tags.add(user_tag)
523+
if len(user_tags) > 1:
524+
logging.info(
525+
f"The data node is used across multiple partitions, including {user_tags}. "
526+
"If the data is too large and it's not preferred to copy, please tag the "
527+
"constant node like node.['no_copy'] = True and they won't be copied."
528+
)
529+
# tag the data node with the same tag as the last user
530+
if len(user_tags) > 0:
532531
node.meta["delegation_tag"] = user_tags.pop()
533532

534533

0 commit comments

Comments
 (0)