Skip to content

[ET-VK] Update RemoveLocalScalarDenseOpsTransform to tag scalar tensors as well #6069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/_passes:remove_local_scalar_dense",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/exir:graph_module",
"//executorch/exir/_serialize:_bindings",
"//executorch/exir/_serialize:lib",
Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,16 @@ runtime.python_library(
"//executorch/exir/dialects:lib",
],
)

# Aggregate Python library target for the Vulkan `_passes` package. It ships
# the package's `__init__.py` and depends on `:remove_local_scalar_dense`, so
# targets that depend on `:vulkan_passes` pick that pass up through this one
# target. Visibility is limited to targets under //executorch/backends/...
runtime.python_library(
name = "vulkan_passes",
srcs = [
"__init__.py",
],
visibility = [
"//executorch/backends/...",
],
deps = [
":remove_local_scalar_dense",
]
)
7 changes: 7 additions & 0 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)

__all__ = [
"RemoveLocalScalarDenseOpsTransform",
]
98 changes: 82 additions & 16 deletions backends/vulkan/_passes/remove_local_scalar_dense_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,95 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensor


def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
    """
    Return True when `node` is the head of a select + local_scalar_dense chain.

    Converting a tensor to a scalar via `tensor[0].item()` shows up in the graph
    as a `select_copy.int` call whose only user is a `_local_scalar_dense` call.
    This predicate matches the `select_copy.int` node that starts that pair.
    """
    is_select_call = (
        node.op == "call_function"
        and node.target == exir_ops.edge.aten.select_copy.int
    )
    # The pattern requires the select result to feed exactly one consumer.
    if not is_select_call or len(node.users) != 1:
        return False

    (sole_user,) = node.users
    return sole_user.target == torch.ops.aten._local_scalar_dense.default


def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
    """
    Tag `node` with `meta["vkdg_is_scalar_tensor"] = True` if it is a scalar tensor.

    A scalar tensor in the Vulkan backend is a tensor that can be represented as a
    scalar value instead of a Tensor object. A node is tagged only when BOTH of the
    following hold:

    1. The tensor has only 1 element
    2. At least one of the node's users converts it to a scalar via
       `tensor[0].item()`, which creates a select_copy + local_scalar_dense pattern
       in the graph

    The tag marks the node so that it can be added as a scalar value during
    serialization. Nodes that carry no "val" metadata, or whose "val" is not a
    FakeTensor, are left untouched.
    """
    # Not every node is guaranteed to carry "val" metadata; use .get() so such
    # nodes are skipped instead of raising KeyError.
    tensor_val = node.meta.get("val")
    if not isinstance(tensor_val, FakeTensor):
        return

    # Scalar tensors must have only one element
    if tensor_val.numel() != 1:
        return

    # One matching user is enough; any() stops at the first hit.
    if any(node_is_local_scalar_dense_chain(user) for user in node.users):
        node.meta["vkdg_is_scalar_tensor"] = True


def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    """
    Reroute every use of a local_scalar_dense `node` to the value it was derived from.

    By default the users of `node` are pointed at its direct input. When that
    input is a `select_copy.int` call applied to a single-element tensor, the
    select is bypassed as well and the users are pointed at that original scalar
    tensor instead. `node` itself is not erased here; it is only left without
    users.
    """
    source = node.args[0]
    assert isinstance(source, torch.fx.Node)

    feeds_from_select = (
        source.op == "call_function"
        and source.target == exir_ops.edge.aten.select_copy.int
    )
    # Only look through the select when its input holds exactly one element,
    # i.e. the select's input is itself the scalar tensor.
    # pyre-ignore
    if feeds_from_select and source.args[0].meta["val"].numel() == 1:
        source = source.args[0]
        assert isinstance(source, torch.fx.Node)
        # Passes when the tag is absent; fails only if it was set to a falsy value.
        assert source.meta.get("vkdg_is_scalar_tensor", True)

    with graph.inserting_after(node):
        node.replace_all_uses_with(source)


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Tag scalar tensors and remove local_scalar_dense op nodes from `graph`.

    The purpose of this pass is twofold:
    1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria)
    2. Remove the select_copy + local_scalar_dense pattern in the graph in favor of
       passing the parent node, or the original scalar tensor, directly (see
       `remove_local_scalar_dense_chain()`)

    This makes it easier to deal with scalar tensors in the Vulkan backend. In
    particular, it allows serializing scalar tensors as SymInt objects instead of
    Tensor objects. Because scalar tensors are often used to inform tensor shapes,
    their values need to be easily accessed by the CPU during resizing logic, while
    also being able to reflect updates to their value in any GPU shaders that
    reference them.

    The graph is modified in place and the same object is returned.
    """
    target_op = torch.ops.aten._local_scalar_dense.default
    for node in graph.nodes:
        tag_node_if_scalar_tensor(node)

        if node.op == "call_function" and node.target == target_op:
            # All select/replace handling lives in the helper; doing it inline
            # here as well would only repeat the same rewiring.
            remove_local_scalar_dense_chain(graph, node)

    # The rewired local_scalar_dense (and any bypassed select) nodes now have no
    # users, so dead code elimination drops them.
    graph.eliminate_dead_code()
    return graph
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
Expand Down
Loading