Skip to content

Commit b7bca17

Browse files
committed
[ET-VK] Update RemoveLocalScalarDenseOpsTransform to tag scalar tensors as well
## Context See the new docstrings added to `remove_local_scalar_dense_ops` for more details on what the pass is trying to achieve. The goal is to mark tensors that are consumed as scalars via `tensor[0].item()` as "scalar tensors" that will be represented as a `SymInt` object in the vulkan delegate instead of a regular `Tensor` object. This diff also adds an `__init__.py` file to the `_passes` folder to make it easier to include Vulkan passes from one place. Differential Revision: [D64139867](https://our.internmc.facebook.com/intern/diff/D64139867/) [ghstack-poisoned]
1 parent 866b40c commit b7bca17

File tree

5 files changed

+104
-20
lines changed

5 files changed

+104
-20
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ runtime.python_library(
2828
"//executorch/backends/transforms:fuse_view_copy",
2929
"//executorch/backends/transforms:mean_to_sum_div",
3030
"//executorch/backends/transforms:remove_clone_ops",
31-
"//executorch/backends/vulkan/_passes:remove_local_scalar_dense",
31+
"//executorch/backends/vulkan/_passes:vulkan_passes",
3232
"//executorch/exir:graph_module",
3333
"//executorch/exir/_serialize:_bindings",
3434
"//executorch/exir/_serialize:lib",

backends/vulkan/_passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,16 @@ runtime.python_library(
4040
"//executorch/exir/dialects:lib",
4141
],
4242
)
43+
44+
runtime.python_library(
45+
name = "vulkan_passes",
46+
srcs = [
47+
"__init__.py",
48+
],
49+
visibility = [
50+
"//executorch/backends/...",
51+
],
52+
deps = [
53+
":remove_local_scalar_dense",
54+
]
55+
)

backends/vulkan/_passes/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
2+
RemoveLocalScalarDenseOpsTransform,
3+
)
4+
5+
__all__ = [
6+
"RemoveLocalScalarDenseOpsTransform",
7+
]

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,95 @@
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass, PassResult
1212

13+
from torch._subclasses.fake_tensor import FakeTensor
14+
15+
16+
def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
17+
"""
18+
Converting a tensor to a scalar via tensor[0].item() creates a index_select +
19+
local_scalar_dense pattern in the graph. Check if a node is the start of this pattern.
20+
"""
21+
if (
22+
node.op == "call_function"
23+
and node.target == exir_ops.edge.aten.select_copy.int
24+
and len(node.users) == 1
25+
):
26+
user = list(node.users.keys())[0]
27+
return user.target == torch.ops.aten._local_scalar_dense.default
28+
29+
return False
30+
31+
32+
def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
33+
"""
34+
A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar
35+
value instead of a Tensor object. The criteria for identifying a tensor as a scalar
36+
tensor are as follows:
37+
38+
1. The tensor has only 1 element
39+
2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which
40+
creates a index_select + local_scalar_dense pattern in the graph
41+
42+
If any of these criteria are fulfilled, then tag the node for the tensor to mark it
43+
so that it is added as a scalar value during serialization.
44+
"""
45+
tensor_val = node.meta["val"]
46+
if not isinstance(tensor_val, FakeTensor):
47+
return
48+
49+
# Scalar tensors must have only one element
50+
if tensor_val.numel() != 1:
51+
return
52+
53+
for user in node.users:
54+
if node_is_local_scalar_dense_chain(user):
55+
node.meta["vkdg_is_scalar_tensor"] = True
56+
57+
58+
def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
59+
"""
60+
Remove the index_select + local_scalar_dense pattern in the graph in favor of passing
61+
the original scalar tensor directly.
62+
"""
63+
replace_node = node.args[0]
64+
assert isinstance(replace_node, torch.fx.Node)
65+
# If the argument to the local_scalar_dense op is a select op with only
66+
# one user, and the argument to the select op is a tensor with only one
67+
# element (i.e. a scalar tensor), then replace the entire pattern with the
68+
# scalar tensor.
69+
if (
70+
replace_node.op == "call_function"
71+
and replace_node.target == exir_ops.edge.aten.select_copy.int
72+
):
73+
# pyre-ignore
74+
if replace_node.args[0].meta["val"].numel() == 1:
75+
replace_node = replace_node.args[0]
76+
assert isinstance(replace_node, torch.fx.Node)
77+
assert replace_node.meta.get("vkdg_is_scalar_tensor", True)
78+
79+
with graph.inserting_after(node):
80+
node.replace_all_uses_with(replace_node)
81+
1382

1483
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
1584
"""
16-
Remove local_scalar_dense op nodes and replace uses with parent node, or the
17-
original scalar tensor.
85+
The purpose of this pass is twofold:
86+
1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria)
87+
2. Remove the index_select + local_scalar_dense pattern in the graph in favor of
88+
passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`)
89+
90+
This makes it easier to deal with scalar tensors in the Vulkan backend. In particular,
91+
it allows serializing scalar tensors as SymInt objects instead of Tensor objects.
92+
Because scalar tensors are often used to inform tensor shapes, their values need to
93+
be easily accessed by the CPU during resizing logic, while also being able to reflect
94+
updates to their value in any GPU shaders that reference them.
1895
"""
1996
target_op = torch.ops.aten._local_scalar_dense.default
2097
for node in graph.nodes:
98+
tag_node_if_scalar_tensor(node)
99+
21100
if node.op == "call_function" and node.target == target_op:
22-
replace_node = node.args[0]
23-
# If the argument to the local_scalar_dense op is a select op with only
24-
# one user, and the argument to the select op is a tensor with only one
25-
# element (i.e. a scalar tensor), then replace the entire pattern with the
26-
# scalar tensor.
27-
if (
28-
replace_node.op == "call_function"
29-
and replace_node.target == exir_ops.edge.aten.select_copy.int
30-
):
31-
if replace_node.args[0].meta["val"].numel() == 1:
32-
replace_node = replace_node.args[0]
33-
34-
with graph.inserting_after(node):
35-
node.replace_all_uses_with(replace_node)
101+
remove_local_scalar_dense_chain(graph, node)
36102

37103
graph.eliminate_dead_code()
38104
return graph

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1818
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1919

20-
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
21-
RemoveLocalScalarDenseOpsTransform,
22-
)
20+
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
2321

2422
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2523
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (

0 commit comments

Comments
 (0)