Skip to content

Commit 400fefa

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Add pass to remove local_scalar_dense (#5886)
Summary: Pull Request resolved: #5886 ## Context Scalar tensors (i.e. tensors with only 1 element) are often passed in to functions as scalars via ``` scalar_tensor[0].item() ``` This translates to the following chain in the graph ``` index_select = index_select(scalar_tensor, ...) scalar = local_scalar_dense(index_select) ``` This diff introduces a pass to remove the `local_scalar_dense` "chain" in favor of passing in the input tensor directly. Note that this replacement only occurs if the original tensor is a scalar tensor. In the Vulkan backend, these scalar tensors will be represented as symbolic integers instead of actual tensors, which is why this replacement is valid. However, it may not a valid replacement for other backends. ghstack-source-id: 246752220 Reviewed By: jorgep31415 Differential Revision: D63913432 fbshipit-source-id: 86ab48ef6a171cd631643db521e8dec041fa63e0
1 parent cb12061 commit 400fefa

File tree

5 files changed

+91
-0
lines changed

5 files changed

+91
-0
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ runtime.python_library(
2828
"//executorch/backends/transforms:fuse_view_copy",
2929
"//executorch/backends/transforms:mean_to_sum_div",
3030
"//executorch/backends/transforms:remove_clone_ops",
31+
"//executorch/backends/vulkan/passes:remove_local_scalar_dense",
3132
"//executorch/exir:graph_module",
3233
"//executorch/exir/_serialize:_bindings",
3334
"//executorch/exir/_serialize:lib",

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,31 @@ def is_linear_permute(self, node: torch.fx.Node) -> bool:
108108

109109
return False
110110

111+
def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
112+
"""
113+
Scalar tensors are usually converted to scalar values in the graph via`
114+
scalar_tensor[0].item()` in Python, which translates to a chain of
115+
`local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
116+
This function marks the entire chain as supported by the Vulkan delegate.
117+
118+
Later, within vulkan_preprocess there will be a graph transform which
119+
replaces the chain with passing in the scalar tensor directly.
120+
"""
121+
if node.target == exir_ops.edge.aten.select_copy.int:
122+
if len(node.users) != 1:
123+
return False
124+
# pyre-ignore
125+
if node.args[0].meta["val"].numel() != 1:
126+
return False
127+
128+
user = list(node.users.keys())[0]
129+
return user.target == torch.ops.aten._local_scalar_dense.default
130+
131+
if node.target == torch.ops.aten._local_scalar_dense.default:
132+
return True
133+
134+
return False
135+
111136
def is_node_supported(
112137
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
113138
) -> bool:
@@ -122,6 +147,9 @@ def _is_node_supported(
122147
if self.is_linear_permute(node):
123148
return True
124149

150+
if self.is_in_local_scalar_dense_chain(node):
151+
return True
152+
125153
if node.target not in VulkanSupportedOperators._ops:
126154
return False
127155

backends/vulkan/passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,16 @@ python_unittest(
2727
"//caffe2:torch",
2828
],
2929
)
30+
31+
runtime.python_library(
32+
name = "remove_local_scalar_dense",
33+
srcs = ["remove_local_scalar_dense_ops.py"],
34+
visibility = [
35+
"//executorch/backends/...",
36+
],
37+
deps = [
38+
"//caffe2:torch",
39+
"//executorch/exir:pass_base",
40+
"//executorch/exir/dialects:lib",
41+
],
42+
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
15+
"""
16+
Remove local_scalar_dense op nodes and replace uses with parent node, or the
17+
original scalar tensor.
18+
"""
19+
target_op = torch.ops.aten._local_scalar_dense.default
20+
for node in graph.nodes:
21+
if node.op == "call_function" and node.target == target_op:
22+
replace_node = node.args[0]
23+
# If the argument to the local_scalar_dense op is a select op with only
24+
# one user, and the argument to the select op is a tensor with only one
25+
# element (i.e. a scalar tensor), then replace the entire pattern with the
26+
# scalar tensor.
27+
if (
28+
replace_node.op == "call_function"
29+
and replace_node.target == exir_ops.edge.aten.select_copy.int
30+
):
31+
if replace_node.args[0].meta["val"].numel() == 1:
32+
replace_node = replace_node.args[0]
33+
34+
with graph.inserting_after(node):
35+
node.replace_all_uses_with(replace_node)
36+
37+
graph.eliminate_dead_code()
38+
return graph
39+
40+
41+
class RemoveLocalScalarDenseOpsTransform(ExportPass):
42+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
43+
graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
44+
return PassResult(graph_module, True)

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1818
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1919

20+
from executorch.backends.vulkan.passes.remove_local_scalar_dense_ops import (
21+
RemoveLocalScalarDenseOpsTransform,
22+
)
23+
2024
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2125
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
2226
serialize_vulkan_graph,
@@ -57,6 +61,7 @@ def preprocess( # noqa: C901
5761
MeanToSumDiv(),
5862
SpecPropPass(),
5963
ConstraintBasedSymShapeEvalPass(),
64+
RemoveLocalScalarDenseOpsTransform(),
6065
MemoryPlanningPass(),
6166
]
6267

0 commit comments

Comments
 (0)