
Commit cefe515

pytorchbot and SS-JIA authored
[ET-VK] Refine partitioner to account for storage type and memory layout (#6668)
Pull Request resolved: #6635

## Context

There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory are:

1. Storage Type (buffer or texture)
2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.)

Due to the differences between buffers and textures, and between the various memory layouts, an operator implementation may support only a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" combination that yields optimal performance.

These changes lay the foundation for a memory metadata tagging graph transform, which will make sure that every tensor participating in an operator call has a valid/optimal (storage type, memory layout) setting, and will insert transition operators to transfer input tensors to the correct memory settings when necessary.

An additional requirement arises from the fact that Vulkan limits texture and buffer sizes. The partitioner therefore needs to account for the storage types and memory layouts supported by the operator implementation, and check that every tensor participating in a computation can be represented with some (storage type, memory layout) combination the implementation supports.

## Changes

Improvements to the operator registry:

* Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator.

Improvements to the partitioner:

* Account for the storage types and memory layouts supported by an operator when deciding whether a node should be partitioned.
* Improve the logic for fusable ops (i.e. the permute/transpose before a mm, which can be fused into linear) to check whether the final target op is supported in Vulkan, and partition those nodes only if it is. Otherwise, leave them unpartitioned so that they can be fused by another backend.

ghstack-source-id: 251883705
@exported-using-ghexport

Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/)

Co-authored-by: Stephen Jia <[email protected]>
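
To make the context concrete, the following is a minimal sketch (not code from this PR) of the kind of representability check the refined partitioner performs, written against the registry API added in `backends/vulkan/op_registry.py` below. `node_is_supported`, `node_tensors`, and `fits_in_storage` are illustrative names, not APIs from this change.

```python
from executorch.backends.vulkan.op_registry import get_op_features, has_impl


def node_tensors(node):
    # Hypothetical helper: collect the tensors participating in the call
    # (the node's inputs and outputs). Stubbed out for this sketch.
    return []


def fits_in_storage(tensor, storage, layout) -> bool:
    # Hypothetical helper: check the tensor against the device's texture or
    # buffer size limits under the given (storage type, memory layout)
    # setting. Stubbed out for this sketch.
    return True


def node_is_supported(node) -> bool:
    # The op must have a Vulkan implementation registered at all...
    if not has_impl(node.target):
        return False

    # ...and there must exist at least one supported (storage type, memory
    # layout) combination that can represent every participating tensor.
    features = get_op_features(node.target)
    for storage in features.supported_storage_types():
        for layout in features.supported_memory_layouts(storage):
            if all(fits_in_storage(t, storage, layout) for t in node_tensors(node)):
                return True

    return False
```
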
1 parent d99d26e commit cefe515

5 files changed: 455 additions & 111 deletions

5 files changed

+455
-111
lines changed

backends/vulkan/op_registry.py

Lines changed: 155 additions & 38 deletions
@@ -8,44 +8,69 @@
 
 import operator
 
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, Optional, Set, Union
 
 import executorch.backends.vulkan.custom_ops_lib  # noqa
 
 import torch
 
-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
+
+from executorch.backends.vulkan.utils import (
+    all_memory_layouts,
+    all_packed_dims,
+    PackedDim,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch._subclasses.fake_tensor import FakeTensor
 
+######################
+## OpFeatures class ##
+######################
+
 
 def allow_node(node: torch.fx.Node) -> bool:
     return True
 
 
 class TextureImplFeatures:
     __slots__ = [
-        # Indicates if the compute shader is agnostic to the packed dimension
-        "uses_packed_dim",
-        # Indicates if the compute shader is agnostic to the texture axis mapping
+        "valid_packed_dims",
         "uses_axis_map",
-        # Specifies a specific set of memory layouts that the shader supports. If it is
-        # and empty list, then the supported memory layouts can be inferred from the
-        # `uses_packed_dim` and `uses_axis_map` flags.
-        "supported_layouts",
     ]
 
     def __init__(
         self,
-        uses_packed_dim: bool = False,
         uses_axis_map: bool = False,
-        supported_layouts: Optional[List[VkMemoryLayout]] = None,
+        valid_packed_dims: Optional[Set[PackedDim]] = None,
     ):
-        self.uses_packed_dim: bool = uses_packed_dim
         self.uses_axis_map: bool = uses_axis_map
-        self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts
+        self.valid_packed_dims = set()
+        if valid_packed_dims is not None:
+            self.valid_packed_dims = valid_packed_dims
+
+    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
+        """
+        Derive the set of memory layouts supported by the texture implementation based
+        on the valid packed dimensions.
+        """
+        layouts = set()
+
+        if PackedDim.WIDTH in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)
+
+        if PackedDim.HEIGHT in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)
+
+        if PackedDim.CHANNELS in self.valid_packed_dims:
+            layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)
+
+        return layouts
 
 
 class OpFeatures:
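
As an illustrative aside (not part of the diff), here is how the new `valid_packed_dims` field translates into memory layouts, mirroring the `register_mm_op` registration further down in this diff:

```python
from executorch.backends.vulkan.op_registry import TextureImplFeatures
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
from executorch.backends.vulkan.utils import PackedDim

# A shader that handles width- and channels-packed tensors.
impl = TextureImplFeatures(
    uses_axis_map=True,
    valid_packed_dims={PackedDim.WIDTH, PackedDim.CHANNELS},
)

# Each valid packed dim maps one-to-one onto a texture memory layout.
assert impl.valid_memory_layouts() == {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
```
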
@@ -58,6 +83,9 @@ class OpFeatures:
         # bool indicating if the operator has a resize function, which allows it to
         # support dynamic shape tensors.
         "resize_fn",
+        # Optimal
+        "optimal_storage",
+        "optimal_layout",
         # bool indicating if the operator handles its own prepacking. If this is True,
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
@@ -72,17 +100,90 @@ def __init__(
         texture_impl: Optional[TextureImplFeatures] = None,
         buffer_impl: bool = False,
         resize_fn: bool = False,
+        optimal_storage: Optional[VkStorageType] = None,
+        optimal_layout: Optional[VkMemoryLayout] = None,
         handles_own_prepacking: bool = False,
         check_node_fn: Optional[Callable] = None,
     ):
         self.texture_impl: Optional[TextureImplFeatures] = texture_impl
         self.buffer_impl: bool = buffer_impl
         self.resize_fn: bool = resize_fn
+        self.optimal_storage: Optional[VkStorageType] = optimal_storage
+        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
         self.handles_own_prepacking: bool = handles_own_prepacking
         self.check_node_fn: Callable = allow_node
         if check_node_fn is not None:
             self.check_node_fn = check_node_fn
 
+    def propose_storage_type(self) -> Optional[VkStorageType]:
+        """
+        Propose a storage type that should be used for this operator. A proposal can be
+        made if one of the following is true:
+        1. The operator specifies an optimal storage type
+        2. Only one storage type is supported.
+
+        If both storage types are supported and no optimal storage type is specified,
+        then None is returned to indicate that there is no preference in storage type.
+        """
+        if self.optimal_storage is not None:
+            return self.optimal_storage
+
+        if self.texture_impl is not None and not self.buffer_impl:
+            return VkStorageType.TEXTURE_3D
+        elif self.buffer_impl and self.texture_impl is None:
+            return VkStorageType.BUFFER
+
+        return None
+
+    def supported_storage_types(self) -> Set[VkStorageType]:
+        """
+        Return the set of storage types supported by this operator.
+        """
+        storage_types = set()
+        if self.texture_impl is not None:
+            storage_types.add(VkStorageType.TEXTURE_3D)
+        if self.buffer_impl:
+            storage_types.add(VkStorageType.BUFFER)
+
+        return storage_types
+
+    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
+        """
+        Given a storage type as a precondition, propose a memory layout that should be
+        used for this operator. A proposal can be made if one of the following is true:
+        1. The operator specifies an optimal memory layout
+        2. Only one memory layout is supported.
+
+        If multiple memory layouts are supported and no optimal memory layout is
+        specified then return None to indicate that the "best" memory layout for the
+        operator is ambiguous.
+        """
+        if self.optimal_layout is not None:
+            return self.optimal_layout
+
+        if storage == VkStorageType.TEXTURE_3D:
+            assert self.texture_impl is not None
+            possible_layouts = self.texture_impl.valid_memory_layouts()
+            if len(possible_layouts) == 1:
+                return next(iter(possible_layouts))
+
+        return None
+
+    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
+        """
+        Return the set of memory layouts supported by this operator for a given storage
+        type.
+        """
+        if storage == VkStorageType.TEXTURE_3D:
+            assert self.texture_impl is not None
+            return self.texture_impl.valid_memory_layouts()
+        else:
+            return all_memory_layouts
+
+
+#######################
+## Operator Registry ##
+#######################
 
 OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]
 
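A quick sketch of how the new proposal methods behave, based on the logic above (illustrative usage, not part of the diff):

```python
from executorch.backends.vulkan.op_registry import OpFeatures, TextureImplFeatures
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.utils import PackedDim

# An op with both storage types implemented and an explicit optimum:
# the optimum wins over any inference.
features = OpFeatures(
    texture_impl=TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH}),
    buffer_impl=True,
    optimal_storage=VkStorageType.TEXTURE_3D,
    optimal_layout=VkMemoryLayout.TENSOR_WIDTH_PACKED,
)
assert features.propose_storage_type() == VkStorageType.TEXTURE_3D

# An op with only a texture impl and a single valid packed dim: the choice is
# unambiguous, so proposals can be made without an explicit optimum.
texture_only = OpFeatures(
    texture_impl=TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH}),
)
assert texture_only.propose_storage_type() == VkStorageType.TEXTURE_3D
assert (
    texture_only.propose_memory_layout(VkStorageType.TEXTURE_3D)
    == VkMemoryLayout.TENSOR_WIDTH_PACKED
)

# An op with both storage types and no optimum: no proposal is made (None);
# the caller, e.g. the memory metadata tagging pass described in the summary,
# must decide.
ambiguous = OpFeatures(texture_impl=TextureImplFeatures(), buffer_impl=True)
assert ambiguous.propose_storage_type() is None
```
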
@@ -122,8 +223,8 @@ def update_features_impl(op: OpKey):
 )
 def register_ephemeral_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
         uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.buffer_impl = True
     features.resize_fn = True
@@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures):
 )
 def register_binary_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
         uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.resize_fn = True
     return features
@@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures):
 )
 def register_unary_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
         uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.buffer_impl = True
     features.resize_fn = True
@@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures):
 @update_features(exir_ops.edge.aten._to_copy.default)
 def register_to_copy_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
         uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.resize_fn = True
 
@@ -220,40 +321,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 )
 def register_mm_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
         uses_axis_map=True,
-        supported_layouts=[
-            VkMemoryLayout.TENSOR_WIDTH_PACKED,
-            VkMemoryLayout.TENSOR_CHANNELS_PACKED,
-        ],
+        valid_packed_dims={
+            PackedDim.WIDTH,
+            PackedDim.CHANNELS,
+        },
     )
     features.buffer_impl = True
     features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
     return features
 
 
 @update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
 def register_int8_mm_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
         uses_axis_map=False,
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
    )
     features.buffer_impl = True
     features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
     return features
 
 
 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=False,
         uses_axis_map=False,
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
     )
     features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
     return features
 
@@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures):
 )
 def register_softmax_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.resize_fn = True
     return features
@@ -282,7 +386,7 @@ def register_softmax_op(features: OpFeatures):
 )
 def register_reduce_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.resize_fn = True
 
@@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
 )
 def register_2d_pool_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
     )
     features.resize_fn = True
     return features
@@ -323,27 +427,31 @@ def register_2d_pool_op(features: OpFeatures):
 )
 def register_convolution_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
     )
     features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
     features.handles_own_prepacking = True
     return features
 
 
 @update_features("llama::sdpa_with_kv_cache")
 def register_sdpa_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
     )
     features.resize_fn = True
+    features.optimal_storage = VkStorageType.TEXTURE_3D
+    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
     return features
 
 
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+        valid_packed_dims={PackedDim.WIDTH},
     )
     features.resize_fn = True
     return features
@@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures):
 @update_features(exir_ops.edge.aten.view_copy.default)
 def register_view_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        uses_packed_dim=True,
+        valid_packed_dims=all_packed_dims,
     )
     features.resize_fn = True
     return features
@@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures):
 )
 def register_ported_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
     )
     return features
 
@@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures):
 )
 def register_ported_ops_with_prepacking(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
-        supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        valid_packed_dims={PackedDim.CHANNELS},
     )
     features.handles_own_prepacking = True
     return features
 
 
-##
-## Utility Functions
-##
+#######################
+## Utility functions ##
+#######################
+
+
+def has_impl(target: OpKey) -> bool:
+    if not isinstance(target, str):
+        if target not in vulkan_supported_ops:
+            return target.name() in vulkan_supported_ops
+        return target in vulkan_supported_ops
+    else:
+        return target in vulkan_supported_ops
 
 
 def get_op_features(target: OpKey) -> OpFeatures:
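
The new `has_impl` utility accepts either a string key or an op overload, falling back to the overload's string name when the overload object itself is not registered. A usage sketch (illustrative, using ops registered in this diff):

```python
from executorch.backends.vulkan.op_registry import get_op_features, has_impl
from executorch.exir.dialects._ops import ops as exir_ops

# String keys (custom ops) are looked up directly.
assert has_impl("llama::sdpa_with_kv_cache")

# OpOverload / EdgeOpOverload keys fall back to target.name() when the
# overload object itself is not a key in vulkan_supported_ops.
assert has_impl(exir_ops.edge.aten.view_copy.default)

# Once an op is known to have an implementation, its features can be queried.
features = get_op_features(exir_ops.edge.aten.view_copy.default)
assert features.resize_fn
```
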

backends/vulkan/partitioner/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ runtime.python_library(
     ],
     deps = [
         "//executorch/backends/vulkan:op_registry",
+        "//executorch/backends/vulkan:utils_lib",
         "//executorch/backends/vulkan:vulkan_preprocess",
         "//executorch/exir:delegate",
         "//executorch/exir:lib",
