Skip to content

[ET-VK] Add pass to remove local_scalar_dense #5886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/passes:remove_local_scalar_dense",
"//executorch/exir:graph_module",
"//executorch/exir/_serialize:_bindings",
"//executorch/exir/_serialize:lib",
Expand Down
28 changes: 28 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@ def is_linear_permute(self, node: torch.fx.Node) -> bool:

return False

def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
    """
    Scalar tensors are usually converted to scalar values in the graph via
    `scalar_tensor[0].item()` in Python, which translates to a chain of
    `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
    This function marks the entire chain as supported by the Vulkan delegate.

    Later, within vulkan_preprocess there will be a graph transform which
    replaces the chain with passing in the scalar tensor directly.
    """
    lsd_target = torch.ops.aten._local_scalar_dense.default

    # The tail of the chain is always supported.
    if node.target == lsd_target:
        return True

    # Otherwise, only a select op that feeds directly into a
    # local_scalar_dense op qualifies as part of the chain.
    if node.target != exir_ops.edge.aten.select_copy.int:
        return False

    consumers = list(node.users.keys())
    if len(consumers) != 1:
        return False

    # The tensor being selected from must contain exactly one element,
    # i.e. it must be a scalar tensor.
    # pyre-ignore
    if node.args[0].meta["val"].numel() != 1:
        return False

    return consumers[0].target == lsd_target

def is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
Expand All @@ -122,6 +147,9 @@ def _is_node_supported(
if self.is_linear_permute(node):
return True

if self.is_in_local_scalar_dense_chain(node):
return True

if node.target not in VulkanSupportedOperators._ops:
return False

Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,16 @@ python_unittest(
"//caffe2:torch",
],
)

# Library for the graph pass that removes local_scalar_dense chains; depended
# on by the Vulkan delegate's preprocess step. Visibility is scoped to the
# executorch backends tree.
runtime.python_library(
    name = "remove_local_scalar_dense",
    srcs = ["remove_local_scalar_dense_ops.py"],
    visibility = [
        "//executorch/backends/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)
44 changes: 44 additions & 0 deletions backends/vulkan/passes/remove_local_scalar_dense_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Remove `aten._local_scalar_dense` nodes and replace their uses with the
    parent node, or with the original scalar tensor when the parent is a
    single-element select (the `scalar_tensor[0].item()` chain).

    Args:
        graph: FX graph to mutate in place.

    Returns:
        The same graph object, with orphaned nodes removed via dead code
        elimination.
    """
    target_op = torch.ops.aten._local_scalar_dense.default
    for node in graph.nodes:
        if node.op != "call_function" or node.target != target_op:
            continue

        replace_node = node.args[0]
        # If the argument to the local_scalar_dense op is a select op whose
        # input is a tensor with only one element (i.e. a scalar tensor),
        # then replace the entire pattern with the scalar tensor itself.
        if (
            replace_node.op == "call_function"
            and replace_node.target == exir_ops.edge.aten.select_copy.int
        ):
            if replace_node.args[0].meta["val"].numel() == 1:
                replace_node = replace_node.args[0]

        # No insertion point is needed here: replace_all_uses_with only
        # rewires existing users and creates no new nodes, so the original
        # `with graph.inserting_after(node):` wrapper was a no-op.
        node.replace_all_uses_with(replace_node)

    graph.eliminate_dead_code()
    return graph


class RemoveLocalScalarDenseOpsTransform(ExportPass):
    """Export pass wrapper around `remove_local_scalar_dense_ops`."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        rewritten_graph = remove_local_scalar_dense_ops(graph_module.graph)
        graph_module.graph = rewritten_graph
        return PassResult(graph_module, True)
5 changes: 5 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan.passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
serialize_vulkan_graph,
Expand Down Expand Up @@ -57,6 +61,7 @@ def preprocess( # noqa: C901
MeanToSumDiv(),
SpecPropPass(),
ConstraintBasedSymShapeEvalPass(),
RemoveLocalScalarDenseOpsTransform(),
MemoryPlanningPass(),
]

Expand Down
Loading