[ET-VK][ez] Improve insert_prepack_node pass to handle multiple uses of constant tensors #10426

Merged 1 commit on Apr 25, 2025
backends/vulkan/_passes/insert_prepack_nodes.py: 21 additions & 20 deletions
@@ -8,10 +8,6 @@
 
 from copy import deepcopy
 
-import executorch.backends.vulkan.custom_ops_lib # noqa
-
-import torch
-
 from executorch.backends.vulkan.op_registry import handles_own_prepacking
 from executorch.backends.vulkan.utils import is_param_node
 
@@ -31,27 +27,27 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
     argument into the operator implementation.
     """
 
-    def prepack_not_required(node: torch.fx.Node) -> bool:
+    for node in program.graph_module.graph.nodes:
+        # Prepacking is only needed for constant tensors. Only nodes corresponding to
+        # constant tensors will proceed beyond this point.
         if not is_param_node(program, node):
-            return True
+            continue
 
-        # Annotate that this node is going to represented as a tensorref in the Vulkan
-        # compute graph. This will be useful for later graph passes.
+        # Mark that this node is going to be represented as a TensorRef type in the
+        # Vulkan compute graph. This annotation is used in later graph passes.
         node.meta["vkdg_tensorref"] = True
 
+        # Get the list of node users that do not handle their own prepacking
+        nodes_to_replace_input = []
         for user in node.users:
-            if user.op == "call_function" and handles_own_prepacking(
-                # pyre-ignore
-                user.target
-            ):
-                return True
+            if user.op == "call_function" and not handles_own_prepacking(user.target):
+                nodes_to_replace_input.append(user)
 
-        return False
-
-    for node in program.graph_module.graph.nodes:
-        if prepack_not_required(node):
+        if len(nodes_to_replace_input) == 0:
             continue
 
+        replace_all_uses = len(nodes_to_replace_input) == len(node.users)
+
         with program.graph_module.graph.inserting_after(node):
             prepack_node = program.graph_module.graph.create_node(
                 "call_function",
@@ -74,9 +70,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
             # memory object.
             prepack_node.meta["spec"].mem_obj_id = -1
-            node.replace_all_uses_with(
-                prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
-            )
+            if replace_all_uses:
+                node.replace_all_uses_with(
+                    prepack_node,
+                    lambda x, y=prepack_node: (x != y and x.op != "output"),
+                )
+            else:
+                for user_node in nodes_to_replace_input:
+                    user_node.replace_input_with(node, prepack_node)
 
     program.graph.eliminate_dead_code()
     return program
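
For context, the change can be illustrated with a small, self-contained torch.fx sketch. This is a hypothetical stand-alone example, not the Vulkan delegate itself: fake_prepack stands in for the et_vk.prepack operator, and a plain operator.mul user plays the role of an op that does not handle its own prepacking. It demonstrates the technique the new code uses: collect the users that need the rewritten input before inserting the new node, then either call replace_all_uses_with (when every user needs it) or replace_input_with per user (when only some do).

import operator

import torch
import torch.fx


class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A constant tensor with two users in the traced graph.
        self.weight = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        a = x + self.weight  # pretend add handles its own prepacking
        b = x * self.weight  # pretend mul needs an explicit prepack node
        return a + b


def fake_prepack(t):
    # Hypothetical stand-in for et_vk.prepack; identity at the fx level.
    return t


gm = torch.fx.symbolic_trace(ToyModule())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == "get_attr" and node.target == "weight":
        # Users that should read the prepacked tensor (here: only mul).
        users_to_rewrite = [u for u in node.users if u.target is operator.mul]
        if not users_to_rewrite:
            continue
        # Decide before inserting the new node, since the prepack node will
        # itself become a user of `node` once created.
        replace_all_uses = len(users_to_rewrite) == len(node.users)
        with graph.inserting_after(node):
            prepack_node = graph.create_node(
                "call_function", fake_prepack, (node,)
            )
        if replace_all_uses:
            # Every user wants the prepacked tensor: rewrite all uses except
            # the prepack node itself.
            node.replace_all_uses_with(
                prepack_node, lambda u: u is not prepack_node
            )
        else:
            # Only some users want it: rewrite those inputs one by one.
            for user in users_to_rewrite:
                user.replace_input_with(node, prepack_node)

graph.lint()
gm.recompile()
print(gm.code)  # mul now reads fake_prepack(weight); add still reads weight.

The replace_all_uses fast path preserves the pass's previous behavior; the per-user replace_input_with branch is what lets a constant tensor be shared between an op that prepacks its own weights and one that relies on an explicit et_vk.prepack node.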