Skip to content

Commit 6563109

Browse files
committed
Move vulkan.passes to vulkan._passes (#5919)
Summary: Changing vulkan.passes to vulkan._passes to indicate that these passes are not covered under the API stability guarantee. Pull Request resolved: #5919 Reviewed By: helunwencser Differential Revision: D63926849 fbshipit-source-id: bf135c46c6718bc37afa640cf51d004891516575 (cherry picked from commit e1832ef)
1 parent 40358fa commit 6563109

File tree

7 files changed

+49
-2
lines changed

7 files changed

+49
-2
lines changed

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import sys
88

99
import torch
10-
from executorch.backends.vulkan.passes.custom_ops_defs import conv_with_clamp_op # noqa
10+
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
11+
conv_with_clamp_op,
12+
)
1113

1214
from executorch.exir.dialects._ops import ops as exir_ops
1315
from executorch.exir.pass_base import ExportPass, PassResult

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ runtime.python_library(
2828
"//executorch/backends/transforms:fuse_view_copy",
2929
"//executorch/backends/transforms:mean_to_sum_div",
3030
"//executorch/backends/transforms:remove_clone_ops",
31+
"//executorch/backends/vulkan/_passes:remove_local_scalar_dense",
3132
"//executorch/exir:graph_module",
3233
"//executorch/exir/_serialize:_bindings",
3334
"//executorch/exir/_serialize:lib",
File renamed without changes.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
13+
14+
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
15+
"""
16+
Remove local_scalar_dense op nodes and replace uses with parent node, or the
17+
original scalar tensor.
18+
"""
19+
target_op = torch.ops.aten._local_scalar_dense.default
20+
for node in graph.nodes:
21+
if node.op == "call_function" and node.target == target_op:
22+
replace_node = node.args[0]
23+
# If the argument to the local_scalar_dense op is a select op with only
24+
# one user, and the argument to the select op is a tensor with only one
25+
# element (i.e. a scalar tensor), then replace the entire pattern with the
26+
# scalar tensor.
27+
if (
28+
replace_node.op == "call_function"
29+
and replace_node.target == exir_ops.edge.aten.select_copy.int
30+
):
31+
if replace_node.args[0].meta["val"].numel() == 1:
32+
replace_node = replace_node.args[0]
33+
34+
with graph.inserting_after(node):
35+
node.replace_all_uses_with(replace_node)
36+
37+
graph.eliminate_dead_code()
38+
return graph
39+
40+
41+
class RemoveLocalScalarDenseOpsTransform(ExportPass):
42+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
43+
graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
44+
return PassResult(graph_module, True)

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import operator
1010

11-
from executorch.backends.vulkan.passes.custom_ops_defs import ( # noqa
11+
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
1212
conv_with_clamp_op,
1313
grid_priors_op,
1414
)

0 commit comments

Comments
 (0)