Skip to content

Commit 858efd2

Browse files
committed
[ET-VK] Add pass to remove local_scalar_dense
## Context

Scalar tensors (i.e. tensors with only 1 element) are often passed in to functions as scalars via

```
scalar_tensor[0].item()
```

This translates to the following chain in the graph

```
index_select = index_select(scalar_tensor, ...)
scalar = local_scalar_dense(index_select)
```

This diff introduces a pass to remove the `local_scalar_dense` "chain" in favor of passing in the input tensor directly. Note that this replacement only occurs if the original tensor is a scalar tensor.

In the Vulkan backend, these scalar tensors will be represented as symbolic integers instead of actual tensors, which is why this replacement is valid. However, it may not be a valid replacement for other backends.

Differential Revision: [D63913432](https://our.internmc.facebook.com/intern/diff/D63913432/)

ghstack-source-id: 246377322

Pull Request resolved: #5886
1 parent bb2623d commit 858efd2

File tree

5 files changed

+90
-0
lines changed

5 files changed

+90
-0
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ runtime.python_library(
2828
"//executorch/backends/transforms:fuse_view_copy",
2929
"//executorch/backends/transforms:mean_to_sum_div",
3030
"//executorch/backends/transforms:remove_clone_ops",
31+
"//executorch/backends/vulkan/passes:remove_local_scalar_dense",
3132
"//executorch/exir:graph_module",
3233
"//executorch/exir/_serialize:_bindings",
3334
"//executorch/exir/_serialize:lib",

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,30 @@ def is_linear_permute(self, node: torch.fx.Node) -> bool:
108108

109109
return False
110110

111+
def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
    """
    Report whether `node` is part of a scalar-extraction chain.

    Reading a scalar tensor as a Python scalar, e.g. `scalar_tensor[0].item()`,
    appears in the graph as
    `_local_scalar_dense(select_copy.int(scalar_tensor, 0, 0))`. This check
    accepts both links of that chain so the whole chain is marked as
    supported by the Vulkan delegate.

    Later, within vulkan_preprocess, a graph transform replaces the chain
    with the scalar tensor itself.

    Returns True when `node` is either:
      * a `select_copy.int` with exactly one user, whose input tensor has a
        single element and whose user is `_local_scalar_dense`, or
      * any `_local_scalar_dense` node.
    """
    local_scalar_dense = torch.ops.aten._local_scalar_dense.default

    if node.target == exir_ops.edge.aten.select_copy.int:
        # The select must feed exactly one consumer and must read from a
        # single-element tensor; otherwise it is not part of the chain.
        if len(node.users) != 1 or node.args[0].meta["val"].numel() != 1:
            return False

        (sole_user,) = node.users
        return sole_user.target == local_scalar_dense

    # The tail of the chain is accepted unconditionally.
    return node.target == local_scalar_dense
134+
111135
def is_node_supported(
112136
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
113137
) -> bool:
@@ -128,6 +152,9 @@ def _is_node_supported(
128152
if self.is_linear_permute(node):
129153
return True
130154

155+
if self.is_in_local_scalar_dense_chain(node):
156+
return True
157+
131158
if target not in VulkanSupportedOperators._ops:
132159
return False
133160

backends/vulkan/passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,16 @@ python_unittest(
2727
"//caffe2:torch",
2828
],
2929
)
30+
31+
runtime.python_library(
32+
name = "remove_local_scalar_dense",
33+
srcs = ["remove_local_scalar_dense_ops.py"],
34+
visibility = [
35+
"//executorch/backends/...",
36+
],
37+
deps = [
38+
"//caffe2:torch",
39+
"//executorch/exir:pass_base",
40+
"//executorch/exir/dialects:lib",
41+
],
42+
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Remove `_local_scalar_dense` nodes from `graph`.

    Every use of a `_local_scalar_dense` node is redirected to that node's
    input. When the input is a `select_copy.int` whose own input holds exactly
    one element (i.e. a scalar tensor), uses are redirected to the scalar
    tensor itself, so the select is bypassed as well. Dead code elimination is
    then run on the graph to drop nodes that no longer have users.

    Args:
        graph: The graph to rewrite. It is modified in place.

    Returns:
        The same `graph` object that was passed in.
    """
    local_scalar_dense = torch.ops.aten._local_scalar_dense.default

    for node in graph.nodes:
        if node.op != "call_function" or node.target != local_scalar_dense:
            continue

        replacement = node.args[0]
        # Skip over a select that reads from a single-element tensor and use
        # that tensor directly. The number of users of the select is not
        # checked here: only the uses of the local_scalar_dense node are
        # rewired, so a select with other consumers simply stays in the graph.
        if (
            replacement.op == "call_function"
            and replacement.target == exir_ops.edge.aten.select_copy.int
            and replacement.args[0].meta["val"].numel() == 1
        ):
            replacement = replacement.args[0]

        # replace_all_uses_with only rewires existing users; no new node is
        # created, so no insertion point needs to be set.
        node.replace_all_uses_with(replacement)

    graph.eliminate_dead_code()
    return graph
39+
40+
41+
class RemoveLocalScalarDenseOpsTransform(ExportPass):
    """
    Pass wrapper that applies `remove_local_scalar_dense_ops` to a graph module.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite `graph_module.graph` and return it wrapped in a `PassResult`.

        The result is always flagged as modified.
        """
        rewritten_graph = remove_local_scalar_dense_ops(graph_module.graph)
        graph_module.graph = rewritten_graph
        return PassResult(graph_module, True)

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1818
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1919

20+
from executorch.backends.vulkan.passes.remove_local_scalar_dense_ops import (
21+
RemoveLocalScalarDenseOpsTransform,
22+
)
23+
2024
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2125
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
2226
serialize_vulkan_graph,
@@ -63,6 +67,7 @@ def preprocess( # noqa: C901
6367
MeanToSumDiv(),
6468
SpecPropPass(),
6569
ConstraintBasedSymShapeEvalPass(),
70+
RemoveLocalScalarDenseOpsTransform(),
6671
MemoryPlanningPass(),
6772
]
6873

0 commit comments

Comments
 (0)