Commit 836d556

pytorchbot and SS-JIA authored
[ET-VK] Introduce memory metadata tagging pass (#6669)
* [ET-VK] Refine partitioner to account for storage type and memory layout

Pull Request resolved: #6635

## Context

There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory are:

1. Storage Type (buffer or texture)
2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.)

Due to the differences between buffers and textures, and the differences between memory layouts, an operator implementation may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting which results in optimal performance.

These changes lay the foundation for a memory metadata tagging graph transform, which will make sure that every tensor participating in an operator call has a valid/optimal (storage type, memory layout) setting, and will insert transition operators to transfer input tensors to the correct memory settings when necessary.

An additional required change arises from the fact that Vulkan imposes limits on texture and buffer sizes. The partitioner therefore needs to account for the storage types and memory layouts supported by the operator implementation, and check that all tensors participating in a computation can be represented with some (storage type, memory layout) combination that the implementation supports.

## Changes

Improvements to the operator registry:

* Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator

Improvements to the Partitioner:

* Account for the storage types and memory layouts supported by an operator when deciding if a node should be partitioned
* Improve the logic for fusable ops (i.e. the permute/transpose before a mm which can be fused into linear) to check whether the final target op is supported in Vulkan, and only partition those nodes if so; otherwise, leave them unpartitioned so that they can be fused by another backend.

ghstack-source-id: 251883705
@exported-using-ghexport

Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/)

* [ET-VK] Introduce memory metadata tagging pass

Pull Request resolved: #6636

## Context

As the title says: this implements the memory metadata tagging graph transform described in the dependent diff. See the code comments for more details.

ghstack-source-id: 251884020
@exported-using-ghexport

Differential Revision: [D65428842](https://our.internmc.facebook.com/intern/diff/D65428842/)

---------

Co-authored-by: Stephen Jia <[email protected]>
1 parent cefe515 commit 836d556
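
For context on the operator registry utilities this commit relies on, the sketch below queries an op's supported and preferred memory settings the way the partitioner and the new pass do. This is an illustration, not part of the commit; the choice of aten.mm as the example target is an assumption.

# Illustrative sketch only: queries the Vulkan op registry the way the
# partitioner and TagMemoryMetaPass do. aten.mm is an assumed example target.
from executorch.backends.vulkan.op_registry import get_op_features, has_impl
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkStorageType
from executorch.exir.dialects._ops import ops as exir_ops

target = exir_ops.edge.aten.mm.default
if has_impl(target):
    features = get_op_features(target)
    # Every storage type the op implementation can handle.
    print(features.supported_storage_types())
    # Preferred storage type, or None if the op has no strong preference.
    print(features.propose_storage_type())
    # Memory layouts that are valid when texture storage is chosen.
    print(features.supported_memory_layouts(VkStorageType.TEXTURE_3D))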

File tree

11 files changed, +404 −16 lines


backends/vulkan/_passes/TARGETS

Lines changed: 23 additions & 7 deletions
@@ -16,6 +16,20 @@ runtime.python_library(
     ],
 )

+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],

@@ -30,17 +44,18 @@ runtime.python_library(
 )

 runtime.python_library(
-    name = "int4_weight_only_quantizer",
-    srcs = [
-        "int4_weight_only_quantizer.py",
-    ],
+    name = "tag_memory_meta_pass",
+    srcs = ["tag_memory_meta_pass.py"],
     visibility = [
         "//executorch/backends/...",
     ],
     deps = [
-        "//executorch/backends/vulkan:custom_ops_lib",
-        "//pytorch/ao:torchao",
-    ]
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan/serialization:lib",
+    ],
 )

 runtime.python_library(

@@ -56,5 +71,6 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":tag_memory_meta_pass"
     ]
 )

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,9 +5,11 @@
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "TagMemoryMetaPass",
 ]
backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 236 additions & 0 deletions

@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from copy import deepcopy
from typing import Set

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import get_op_features, has_impl

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensor

from torch.fx.passes.tools_common import NodeList
from torch.fx.passes.utils.fuser_utils import topo_sort

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


def set_memory_metadata(
    node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout
) -> None:
    utils.set_node_spec_attr(node, "vk_storage_type", storage)
    utils.set_node_spec_attr(node, "vk_memory_layout", layout)


class TagMemoryMetaPass(ExportPass):
    """
    There are a variety of ways that tensors can be represented in Vulkan. The two main
    descriptors for how a tensor is laid out in memory are:

    1. Storage Type (buffer or texture)
    2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.)

    Due to the differences between buffers and textures, and the differences between
    different memory layouts, an implementation for an operator may only support a
    specific set of (storage type, memory layout) combinations.

    Furthermore, if an operator implementation supports multiple (storage type, memory
    layout) combinations, there may be a "preferred" setting which results in optimal
    performance.

    This pass is responsible for ensuring that all tensors participating in an operator
    call have a valid/optimal (storage type, memory layout) setting, and for inserting
    transition operators to transfer input tensors to the correct memory settings when
    necessary.
    """

    def __init__(
        self,
        texture_limits: utils.ImageExtents,
        default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
        default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
    ):
        super().__init__()
        self.default_storage: VkStorageType = default_storage_type
        self.default_layout: VkMemoryLayout = default_memory_layout
        self.texture_limits = texture_limits

    def propose_node_storage(
        self,
        node: torch.fx.Node,
    ) -> VkStorageType:
        """
        Uses the operator registry to determine the storage type that should be used
        for a given node. The storage type is determined with the following priorities:
        1. In some cases, a tensor involved in the computation may be too large to be
           represented as a texture. If this is the case, the node is "opinionated" and
           buffer representation must be used.
        2. If the operator called by the node indicates an optimal storage type, or only
           supports a single storage type, use that storage type. If either is true,
           then the node is considered to be opinionated as well. If multiple storage
           types are supported and no preferred storage type is indicated, then the
           node is not opinionated; go to the next step.
        3. If the node's arguments already have memory metadata annotations, then
           preserve the settings of the first argument. Otherwise, proceed to the next
           step.
        4. Recursively search the node's uses to see if any subsequent uses are
           opinionated; inherit the settings of the first opinionated node. If no
           opinionated user can be found, then proceed to the last step.
        5. Use the default storage type setting.
        """
        # The node may have an input/output tensor that is too big to be stored in a
        # texture. In this case, buffer storage must be used. Note that the partitioner
        # has already checked for the fact that buffer storage is supported by the
        # operator.
        if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0:
            return VkStorageType.BUFFER

        valid_storage_types: Set[VkStorageType] = utils.all_storage_types

        # pyre-ignore
        if has_impl(node.target):
            # pyre-ignore
            features = get_op_features(node.target)
            valid_storage_types = features.supported_storage_types()
            storage = features.propose_storage_type()
            if storage is not None:
                return storage

        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and isinstance(
                arg.meta["val"], FakeTensor
            ):
                storage = utils.get_node_storage_type(arg)
                if storage is not None and storage in valid_storage_types:
                    return storage

        # If no storage type has been resolved yet, assume the optimal storage type of
        # the first opinionated user. This search is recursive.
        for user in node.users:
            optimal_storage = self.propose_node_storage(user)
            if optimal_storage is not None:
                return optimal_storage

        if self.default_storage in valid_storage_types:
            return self.default_storage
        else:
            return next(iter(valid_storage_types))

    def propose_node_layout(
        self,
        node: torch.fx.Node,
        storage: VkStorageType,
    ) -> VkMemoryLayout:
        """
        Performs the same steps as propose_node_storage, but detects the memory layout
        that should be used for the specific storage type. The same prioritization
        logic is applied.
        """
        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
        # pyre-ignore
        if has_impl(node.target):
            # pyre-ignore
            features = get_op_features(node.target)
            valid_layouts = features.supported_memory_layouts(storage)
            layout = features.propose_memory_layout(storage)
            if layout is not None:
                return layout

        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and isinstance(
                arg.meta["val"], FakeTensor
            ):
                layout = utils.get_node_memory_layout(arg)
                if layout is not None and layout in valid_layouts:
                    return layout

        # If no memory layout has been resolved yet, assume the optimal layout of the
        # first opinionated user. This search is recursive.
        for user in node.users:
            optimal_layout = self.propose_node_layout(user, storage)
            if optimal_layout is not None:
                return optimal_layout

        # As a last resort, return the default memory layout that should be used.
        if self.default_layout in valid_layouts:
            return self.default_layout
        else:
            return next(iter(valid_layouts))

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

        for node in sorted_nodes:
            if not isinstance(node.meta["val"], FakeTensor):
                continue

            if node.target == exir_ops.edge.et_vk.prepack.default:
                continue

            storage = self.propose_node_storage(node)
            layout = self.propose_node_layout(node, storage)

            set_memory_metadata(node, storage, layout)

            inserting_transitions_for_node = False
            for i, arg in enumerate(node.args):
                if not isinstance(arg, torch.fx.Node):
                    continue
                if not isinstance(arg.meta["val"], FakeTensor):
                    continue

                arg_storage = utils.get_node_storage_type(arg)
                arg_layout = utils.get_node_memory_layout(arg)

                if arg_storage is None:
                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
                    arg_storage = storage
                if arg_layout is None:
                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
                    arg_layout = layout

                if arg_storage == storage and arg_layout == layout:
                    continue

                if not inserting_transitions_for_node:
                    inserting_transitions_for_node = True
                    logger.info(
                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                    )

                logger.info(
                    f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                )

                # Insert a clone node to copy the original tensor to a tensor with the
                # desired storage type and memory layout.
                with graph_module.graph.inserting_before(node):
                    clone_node = graph_module.graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.clone.default,
                        (arg,),
                    )
                    clone_node.meta["val"] = arg.meta["val"]
                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
                    clone_node.meta["spec"].const = False
                    set_memory_metadata(clone_node, storage, layout)
                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

        return PassResult(graph_module, True)
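
A minimal usage sketch of the new pass (not part of the commit). It assumes a torch.fx GraphModule with populated node specs is already available — in practice the Vulkan delegate would run the pass internally during preprocessing — and the texture limit values below are placeholders:

# Hypothetical usage sketch; the texture_limits values and the pre-existing
# `graph_module` are assumptions, not taken from this commit.
from executorch.backends.vulkan._passes import TagMemoryMetaPass
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

# Assumed per-device texture extents (width, height, depth).
texture_limits = (16384, 16384, 2048)

tag_pass = TagMemoryMetaPass(
    texture_limits,
    default_storage_type=VkStorageType.TEXTURE_3D,
    default_memory_layout=VkMemoryLayout.TENSOR_WIDTH_PACKED,
)

# `graph_module` stands in for the fx.GraphModule the Vulkan delegate is about
# to serialize, with node.meta["val"] / node.meta["spec"] already populated.
result = tag_pass.call(graph_module)
graph_module = result.graph_module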

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 5 additions & 3 deletions
@@ -94,9 +94,11 @@ def op_node_is_compatible(
         # If there are no valid texture memory layouts, then buffer storage must be
         # supported by the operator implementation.
         if len(valid_texture_layouts) == 0:
-            # TODO: once memory metadata tagging pass is implemented, check that the
-            # op impl supports buffers instead
-            return False, "requires buffer representation"
+            compatible = VkStorageType.BUFFER in features.supported_storage_types()
+            reason = "op is compatible"
+            if not compatible:
+                reason = "op requires buffers which is not supported by op impl"
+            return compatible, reason

         op_available_layouts = features.supported_memory_layouts(
             VkStorageType.TEXTURE_3D

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 16 additions & 0 deletions
@@ -12,6 +12,11 @@
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

 import torch
+
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,

@@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id

+        storage_type = VkStorageType.DEFAULT_STORAGE
+        memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
+        if hasattr(spec, "vk_storage_type"):
+            # pyre-ignore[16]
+            storage_type = spec.vk_storage_type
+        if hasattr(spec, "vk_memory_layout"):
+            # pyre-ignore[16]
+            memory_layout = spec.vk_memory_layout
+
         new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(

@@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
                     dims=spec.shape,
                     constant_id=constant_id,
                     mem_obj_id=mem_obj_id,
+                    storage_type=storage_type,
+                    memory_layout=memory_layout,
                 )
             )
         )

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 6 additions & 0 deletions
@@ -37,13 +37,19 @@ class VkStorageType(IntEnum):
     TEXTURE_2D = 2
     DEFAULT_STORAGE = 255

+    def __str__(self) -> str:
+        return self.name
+

 class VkMemoryLayout(IntEnum):
     TENSOR_WIDTH_PACKED = 0
     TENSOR_HEIGHT_PACKED = 1
     TENSOR_CHANNELS_PACKED = 2
     DEFAULT_LAYOUT = 255

+    def __str__(self) -> str:
+        return self.name
+

 @dataclass
 class VkTensor:
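
The __str__ overrides above are presumably there so that log output (such as the transition messages emitted by TagMemoryMetaPass) shows readable enum names; a trivial illustration, not part of the commit:

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

# With the overrides, str() yields the member name rather than the default
# IntEnum rendering.
print(str(VkStorageType.TEXTURE_3D), str(VkMemoryLayout.TENSOR_WIDTH_PACKED))
# TEXTURE_3D TENSOR_WIDTH_PACKED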

backends/vulkan/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -223,6 +223,8 @@ def define_common_targets(is_fbcode = False):
         ],
         deps = [
             "//caffe2:torch",
+            "//executorch/exir:tensor",
+            "//executorch/backends/vulkan/serialization:lib",
         ]
     )
