Skip to content

fix coreml partitioner for llama #2816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,26 +509,25 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
underlying data will be owned by multiple delegates.
"""
for node in edge_program.graph.nodes:
# go through const/param/buffer nodes
is_attr = (
node.op == "placeholder"
and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
)
) or (node.op == "get_attr")
# if all users of const/param/buffer nodes are partitioned then partition
if is_attr:
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
if node.op == "placeholder" and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
):
user_tags = set()
for user in node.users:
user_tags.add(user.meta.get("delegation_tag", None))
assert len(user_tags) <= 1, (
"Const/Param/Buffer users have multiple tags because one constant data can't "
"be owned by multiple backends. Consider duplicating the constant data so that "
"each user is unique"
)
if len(user_tags) == 1:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# tag the data node with the same tag as the last user
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()


Expand Down