Skip to content

Commit 91c382d

Browse files
pytorchbotSS-JIA
andauthored
[ET-VK][AOT][ez] Introduce vulkan export utils lib (#6605)
Pull Request resolved: #6600 ## Changes As title. Introduce a common Python utility library for scripts in the Vulkan backend. ghstack-source-id: 251223077 Differential Revision: [D65291064](https://our.internmc.facebook.com/intern/diff/D65291064/) --------- Co-authored-by: Stephen Jia <[email protected]>
1 parent 3aaf584 commit 91c382d

File tree

5 files changed

+56
-46
lines changed

5 files changed

+56
-46
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ runtime.python_library(
1212
deps = [
1313
"//caffe2:torch",
1414
"//executorch/exir:pass_base",
15+
"//executorch/backends/vulkan:utils_lib",
1516
],
1617
)
1718

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
import torch
1414

1515
from executorch.backends.vulkan.op_registry import handles_own_prepacking
16+
from executorch.backends.vulkan.utils import is_param_node
1617

1718
from executorch.exir.dialects._ops import ops as exir_ops
1819

19-
from torch._export.utils import is_buffer, is_param
2020
from torch.export import ExportedProgram
2121

2222

@@ -31,25 +31,8 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
3131
argument into the operator implementation.
3232
"""
3333

34-
def is_get_attr_node(node: torch.fx.Node) -> bool:
35-
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
36-
37-
def is_constant(node: torch.fx.Node) -> bool:
38-
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
39-
40-
def is_param_node(node: torch.fx.Node) -> bool:
41-
"""
42-
Check if the given node is a parameter within the exported program
43-
"""
44-
return (
45-
is_get_attr_node(node)
46-
or is_param(program, node)
47-
or is_buffer(program, node)
48-
or is_constant(node)
49-
)
50-
5134
def prepack_not_required(node: torch.fx.Node) -> bool:
52-
if not is_param_node(node):
35+
if not is_param_node(program, node):
5336
return True
5437

5538
for user in node.users:

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
1313

1414
import torch
15+
from executorch.backends.vulkan.utils import (
16+
is_constant,
17+
is_get_attr_node,
18+
is_param_node,
19+
)
1520
from executorch.exir.backend.utils import DelegateMappingBuilder
1621

1722
from executorch.exir.tensor import TensorSpec
@@ -68,34 +73,12 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
6873
else:
6974
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
7075

71-
def is_constant(self, node: Node):
72-
return (
73-
node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
74-
)
75-
76-
def is_get_attr_node(self, node: Node) -> bool:
77-
"""
78-
Returns true if the given node is a get attr node for a tensor of the model
79-
"""
80-
return isinstance(node, Node) and node.op == "get_attr"
81-
82-
def is_param_node(self, node: Node) -> bool:
83-
"""
84-
Check if the given node is a parameter within the exported program
85-
"""
86-
return (
87-
self.is_get_attr_node(node)
88-
or is_param(self.program, node)
89-
or is_buffer(self.program, node)
90-
or self.is_constant(node)
91-
)
92-
9376
def get_constant(self, node: Node) -> Optional[torch.Tensor]:
9477
"""
9578
Returns the constant associated with the given node in the exported program.
9679
Returns None if the node is not a constant within the exported program
9780
"""
98-
if self.is_constant(node):
81+
if is_constant(self.program, node):
9982
constant_name = (
10083
self.program.graph_signature.inputs_to_lifted_tensor_constants[
10184
node.name
@@ -116,9 +99,9 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
11699
tensor = get_param(self.program, node)
117100
elif is_buffer(self.program, node):
118101
tensor = get_buffer(self.program, node)
119-
elif self.is_constant(node):
102+
elif is_constant(self.program, node):
120103
tensor = self.get_constant(node)
121-
elif self.is_get_attr_node(node):
104+
elif is_get_attr_node(node):
122105
# This is a hack to support both lifted and unlifted graph
123106
try:
124107
tensor = getattr(node.graph.owning_module, node.target)
@@ -132,7 +115,7 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
132115

133116
def maybe_add_constant_tensor(self, node: Node) -> int:
134117
constant_id = -1
135-
if self.is_param_node(node):
118+
if is_param_node(self.program, node):
136119
constant_id = len(self.const_tensors)
137120
self.const_tensors.append(self.get_param_tensor(node))
138121

@@ -280,7 +263,7 @@ def process_placeholder_node(self, node: Node) -> None:
280263
if len(node.users) == 0:
281264
return None
282265
ids = self.create_node_value(node)
283-
if not self.is_param_node(node):
266+
if not is_param_node(self.program, node):
284267
if isinstance(ids, int):
285268
self.input_ids.append(ids)
286269
else:

backends/vulkan/targets.bzl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,19 @@ def define_common_targets(is_fbcode = False):
213213
## AOT targets
214214
##
215215
if is_fbcode:
216+
runtime.python_library(
217+
name = "utils_lib",
218+
srcs = [
219+
"utils.py",
220+
],
221+
visibility = [
222+
"//executorch/backends/vulkan/...",
223+
],
224+
deps = [
225+
"//caffe2:torch",
226+
]
227+
)
228+
216229
runtime.python_library(
217230
name = "custom_ops_lib",
218231
srcs = [

backends/vulkan/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from torch._export.utils import is_buffer, is_param
9+
10+
from torch.export import ExportedProgram
11+
12+
13+
def is_get_attr_node(node: torch.fx.Node) -> bool:
14+
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
15+
16+
17+
def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
18+
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
19+
20+
21+
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
22+
"""
23+
Check if the given node is a parameter within the exported program
24+
"""
25+
return (
26+
is_get_attr_node(node)
27+
or is_param(program, node)
28+
or is_buffer(program, node)
29+
or is_constant(program, node)
30+
)

0 commit comments

Comments
 (0)