import operator

- from typing import Callable, Dict, List, Optional, Union
+ from typing import Callable, Dict, Optional, Set, Union

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

- from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
+ from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+     VkMemoryLayout,
+     VkStorageType,
+ )
+
+ from executorch.backends.vulkan.utils import (
+     all_memory_layouts,
+     all_packed_dims,
+     PackedDim,
+ )
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

+ ######################
+ ## OpFeatures class ##
+ ######################
+

def allow_node(node: torch.fx.Node) -> bool:
    return True


class TextureImplFeatures:
    __slots__ = [
-         # Indicates if the compute shader is agnostic to the packed dimension
-         "uses_packed_dim",
-         # Indicates if the compute shader is agnostic to the texture axis mapping
+         "valid_packed_dims",
        "uses_axis_map",
-         # Specifies a specific set of memory layouts that the shader supports. If it is
-         # and empty list, then the supported memory layouts can be inferred from the
-         # `uses_packed_dim` and `uses_axis_map` flags.
-         "supported_layouts",
    ]

    def __init__(
        self,
-         uses_packed_dim: bool = False,
        uses_axis_map: bool = False,
-         supported_layouts: Optional[List[VkMemoryLayout]] = None,
+         valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
-         self.uses_packed_dim: bool = uses_packed_dim
        self.uses_axis_map: bool = uses_axis_map
-         self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts
+         self.valid_packed_dims = set()
+         if valid_packed_dims is not None:
+             self.valid_packed_dims = valid_packed_dims
+
+     def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
+         """
+         Derive the set of memory layouts supported by the texture implementation based
+         on the valid packed dimensions.
+         """
+         layouts = set()
+
+         if PackedDim.WIDTH in self.valid_packed_dims:
+             layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)
+
+         if PackedDim.HEIGHT in self.valid_packed_dims:
+             layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)
+
+         if PackedDim.CHANNELS in self.valid_packed_dims:
+             layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)
+
+         return layouts


class OpFeatures:
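
For illustration only (not part of the diff), here is a minimal sketch of how the new valid_packed_dims field drives valid_memory_layouts(); the constructed value is hypothetical, and the enum members are the ones used in the diff above.

# Illustrative sketch, not part of the diff: a shader that supports width- and
# channels-packed textures reports exactly those two memory layouts.
features = TextureImplFeatures(
    uses_axis_map=True,
    valid_packed_dims={PackedDim.WIDTH, PackedDim.CHANNELS},
)
assert features.valid_memory_layouts() == {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
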
@@ -58,6 +83,9 @@ class OpFeatures:
        # bool indicating if the operator has a resize function, which allows it to
        # support dynamic shape tensors.
        "resize_fn",
+         # Optimal storage type and memory layout for the operator, if specified.
+         "optimal_storage",
+         "optimal_layout",
        # bool indicating if the operator handles its own prepacking. If this is True,
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
@@ -72,17 +100,90 @@ def __init__(
        texture_impl: Optional[TextureImplFeatures] = None,
        buffer_impl: bool = False,
        resize_fn: bool = False,
+         optimal_storage: Optional[VkStorageType] = None,
+         optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        check_node_fn: Optional[Callable] = None,
    ):
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
        self.buffer_impl: bool = buffer_impl
        self.resize_fn: bool = resize_fn
+         self.optimal_storage: Optional[VkStorageType] = optimal_storage
+         self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking
        self.check_node_fn: Callable = allow_node
        if check_node_fn is not None:
            self.check_node_fn = check_node_fn

+     def propose_storage_type(self) -> Optional[VkStorageType]:
+         """
+         Propose a storage type that should be used for this operator. A proposal can be
+         made if one of the following is true:
+         1. The operator specifies an optimal storage type
+         2. Only one storage type is supported.
+
+         If both storage types are supported and no optimal storage type is specified,
+         then None is returned to indicate that there is no preference in storage type.
+         """
+         if self.optimal_storage is not None:
+             return self.optimal_storage
+
+         if self.texture_impl is not None and not self.buffer_impl:
+             return VkStorageType.TEXTURE_3D
+         elif self.buffer_impl and self.texture_impl is None:
+             return VkStorageType.BUFFER
+
+         return None
+
+     def supported_storage_types(self) -> Set[VkStorageType]:
+         """
+         Return the set of storage types supported by this operator.
+         """
+         storage_types = set()
+         if self.texture_impl is not None:
+             storage_types.add(VkStorageType.TEXTURE_3D)
+         if self.buffer_impl:
+             storage_types.add(VkStorageType.BUFFER)
+
+         return storage_types
+
+     def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
+         """
+         Given a storage type as a precondition, propose a memory layout that should be
+         used for this operator. A proposal can be made if one of the following is true:
+         1. The operator specifies an optimal memory layout
+         2. Only one memory layout is supported.
+
+         If multiple memory layouts are supported and no optimal memory layout is
+         specified then return None to indicate that the "best" memory layout for the
+         operator is ambiguous.
+         """
+         if self.optimal_layout is not None:
+             return self.optimal_layout
+
+         if storage == VkStorageType.TEXTURE_3D:
+             assert self.texture_impl is not None
+             possible_layouts = self.texture_impl.valid_memory_layouts()
+             if len(possible_layouts) == 1:
+                 return next(iter(possible_layouts))
+
+         return None
+
+     def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
+         """
+         Return the set of memory layouts supported by this operator for a given storage
+         type.
+         """
+         if storage == VkStorageType.TEXTURE_3D:
+             assert self.texture_impl is not None
+             return self.texture_impl.valid_memory_layouts()
+         else:
+             return all_memory_layouts
+
+
+ #######################
+ ## Operator Registry ##
+ #######################

OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]
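
For illustration only (not part of the diff), a sketch of how a caller such as a partitioner or layout-assignment pass might consume the new proposal API; the OpFeatures instance below is hypothetical.

# Illustrative sketch, not part of the diff: an op with a width-packed texture
# shader, a buffer fallback, and an explicitly declared optimal storage type.
features = OpFeatures(
    texture_impl=TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH}),
    buffer_impl=True,
    optimal_storage=VkStorageType.TEXTURE_3D,
)
storage = features.propose_storage_type()              # TEXTURE_3D (explicit optimal wins)
layouts = features.supported_memory_layouts(storage)   # {TENSOR_WIDTH_PACKED}
layout = features.propose_memory_layout(storage)       # TENSOR_WIDTH_PACKED (only candidate)
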
@@ -122,8 +223,8 @@ def update_features_impl(op: OpKey):
)
def register_ephemeral_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
        uses_axis_map=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
@@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures):
)
def register_binary_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
        uses_axis_map=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features
@@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures):
)
def register_unary_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
        uses_axis_map=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
@@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures):
@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
        uses_axis_map=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
@@ -220,40 +321,43 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
)
def register_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=False,
        uses_axis_map=True,
-         supported_layouts=[
-             VkMemoryLayout.TENSOR_WIDTH_PACKED,
-             VkMemoryLayout.TENSOR_CHANNELS_PACKED,
-         ],
+         valid_packed_dims={
+             PackedDim.WIDTH,
+             PackedDim.CHANNELS,
+         },
    )
    features.buffer_impl = True
    features.resize_fn = True
+     features.optimal_storage = VkStorageType.TEXTURE_3D
+     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=False,
        uses_axis_map=False,
-         supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+         valid_packed_dims={PackedDim.WIDTH},
    )
    features.buffer_impl = True
    features.resize_fn = True
+     features.optimal_storage = VkStorageType.TEXTURE_3D
+     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=False,
        uses_axis_map=False,
-         supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+         valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
+     features.optimal_storage = VkStorageType.TEXTURE_3D
+     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features
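
For illustration only (not part of the diff), the same registration pattern applied to a hypothetical operator: the op name and register function below are made up, while the decorator and feature fields are the ones used in the hunks above.

# Hypothetical example, not part of the diff: a made-up op whose texture shader
# only handles width-packed tensors and prefers texture storage.
@update_features("example::my_custom_op")  # hypothetical op name
def register_my_custom_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    return features
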
@@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures):
)
def register_softmax_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features
@@ -282,7 +386,7 @@ def register_softmax_op(features: OpFeatures):
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
@@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
)
def register_2d_pool_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+         valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    return features
@@ -323,27 +427,31 @@ def register_2d_pool_op(features: OpFeatures):
)
def register_convolution_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+         valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
+     features.optimal_storage = VkStorageType.TEXTURE_3D
+     features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.handles_own_prepacking = True
    return features


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+         valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
+     features.optimal_storage = VkStorageType.TEXTURE_3D
+     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED],
+         valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    return features
@@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures):
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         uses_packed_dim=True,
+         valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features
@@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures):
)
def register_ported_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+         valid_packed_dims={PackedDim.CHANNELS},
    )
    return features
@@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures):
)
def register_ported_ops_with_prepacking(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
-         supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+         valid_packed_dims={PackedDim.CHANNELS},
    )
    features.handles_own_prepacking = True
    return features


- ##
- ## Utility Functions
- ##
+ #######################
+ ## Utility functions ##
+ #######################
+
+
+ def has_impl(target: OpKey) -> bool:
+     if not isinstance(target, str):
+         if target not in vulkan_supported_ops:
+             return target.name() in vulkan_supported_ops
+         return target in vulkan_supported_ops
+     else:
+         return target in vulkan_supported_ops


def get_op_features(target: OpKey) -> OpFeatures:
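
For illustration only (not part of the diff), how the new has_impl helper pairs with get_op_features; the aten._weight_int8pack_mm entry is the one registered in the diff above, and the expected results follow from its optimal_* settings.

# Illustrative sketch, not part of the diff: has_impl accepts either an
# OpOverload/EdgeOpOverload target (falling back to its name) or a plain string key.
target = exir_ops.edge.aten._weight_int8pack_mm.default
if has_impl(target):
    features = get_op_features(target)
    storage = features.propose_storage_type()         # TEXTURE_3D per the entry above
    layout = features.propose_memory_layout(storage)  # TENSOR_WIDTH_PACKED
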