|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | +from copy import deepcopy |
| 9 | +from typing import Set |
| 10 | + |
| 11 | +import executorch.backends.vulkan.utils as utils |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +from executorch.backends.vulkan.op_registry import get_op_features, has_impl |
| 16 | + |
| 17 | +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( |
| 18 | + VkMemoryLayout, |
| 19 | + VkStorageType, |
| 20 | +) |
| 21 | + |
| 22 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 23 | + |
| 24 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 25 | + |
| 26 | +from torch._subclasses.fake_tensor import FakeTensor |
| 27 | + |
| 28 | +from torch.fx.passes.tools_common import NodeList |
| 29 | +from torch.fx.passes.utils.fuser_utils import topo_sort |
| 30 | + |
| 31 | +logger: logging.Logger = logging.getLogger("") |
| 32 | +logger.setLevel(logging.INFO) |
| 33 | + |
| 34 | + |
| 35 | +def set_memory_metadata( |
| 36 | + node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout |
| 37 | +) -> None: |
| 38 | + utils.set_node_spec_attr(node, "vk_storage_type", storage) |
| 39 | + utils.set_node_spec_attr(node, "vk_memory_layout", layout) |
| 40 | + |
| 41 | + |
| 42 | +class TagMemoryMetaPass(ExportPass): |
| 43 | + """ |
| 44 | + There are a variety of ways that tensors can be represented in Vulkan. The two main |
| 45 | + descriptors for how a tensor is laid out in memory is: |
| 46 | +
|
| 47 | + 1. Storage Type (buffer or texture) |
| 48 | + 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) |
| 49 | +
|
| 50 | + Due to the differences between buffers and textures, and the differences between |
| 51 | + different memory layouts, an implementation for an operator may only support a |
| 52 | + specific set of (storage type, memory layout) combinations. |
| 53 | +
|
| 54 | + Furthermore, if an operator implementation supports multiple (storage type, memory |
| 55 | + layout) combinations, there may be a "preferred" setting which results in optimal |
| 56 | + performance. |
| 57 | +
|
| 58 | + This pass is responsible for ensuring that all tensors participating in an operator |
| 59 | + call have a valid/optimal (storage type, memory layout) setting, and insert |
| 60 | + transition operators to transfer input tensors to the correct memory settings when |
| 61 | + necessary. |
| 62 | + """ |
| 63 | + |
| 64 | + def __init__( |
| 65 | + self, |
| 66 | + texture_limits: utils.ImageExtents, |
| 67 | + default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, |
| 68 | + default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, |
| 69 | + ): |
| 70 | + super().__init__() |
| 71 | + self.default_storage: VkStorageType = default_storage_type |
| 72 | + self.default_layout: VkMemoryLayout = default_memory_layout |
| 73 | + self.texture_limits = texture_limits |
| 74 | + |
| 75 | + def propose_node_storage( |
| 76 | + self, |
| 77 | + node: torch.fx.Node, |
| 78 | + ) -> VkStorageType: |
| 79 | + """ |
| 80 | + Uses the operator registry to determine the storage type that should be used for |
| 81 | + a given node. The storage type is determined with the following priorities: |
| 82 | + 1. In some cases, a tensor involved in the computation may be too large to be |
| 83 | + represented as a texture. If this is the case, the node is "opinionated" and |
| 84 | + buffer representation must be used. |
| 85 | + 1. If the operator called by the node indicates an optimal storage type, or only |
| 86 | + supports a single storage type, use that storage type. If either is true, |
| 87 | + then the node is considered to be opinionated as well. If multiple storage |
| 88 | + and no preferred storage type is indicated, then the node is not opinionated; |
| 89 | + go to the next step. |
| 90 | + 2. If the node's arguments already have memory metadata annotations, then |
| 91 | + preserve the settings of the first argument. Otherwise, proceed to the next |
| 92 | + step. |
| 93 | + 3. Recursively search the node's uses to see if any subsequent uses are |
| 94 | + opinionated; inherit the settings of the first opinionated node. If no |
| 95 | + opinionated user can be found, then proceed to the last step. |
| 96 | + 4. Use the default storage type setting. |
| 97 | + """ |
| 98 | + # The node may have an input/output tensor that is too big to be stored in a |
| 99 | + # texture. In this case, buffer storage must be used. Note that the partitioner |
| 100 | + # has already checked for the fact that buffer storage is supported by the |
| 101 | + # operator. |
| 102 | + if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: |
| 103 | + return VkStorageType.BUFFER |
| 104 | + |
| 105 | + valid_storage_types: Set[VkStorageType] = utils.all_storage_types |
| 106 | + |
| 107 | + # pyre-ignore |
| 108 | + if has_impl(node.target): |
| 109 | + # pyre-ignore |
| 110 | + features = get_op_features(node.target) |
| 111 | + valid_storage_types = features.supported_storage_types() |
| 112 | + storage = features.propose_storage_type() |
| 113 | + if storage is not None: |
| 114 | + return storage |
| 115 | + |
| 116 | + for arg in node.args: |
| 117 | + if isinstance(arg, torch.fx.Node) and isinstance( |
| 118 | + arg.meta["val"], FakeTensor |
| 119 | + ): |
| 120 | + storage = utils.get_node_storage_type(arg) |
| 121 | + if storage is not None and storage in valid_storage_types: |
| 122 | + return storage |
| 123 | + |
| 124 | + # If no storage type has been resolved yet, assume the optimal storage type of |
| 125 | + # the first opinionated user. This search is recursive. |
| 126 | + for user in node.users: |
| 127 | + optimal_storage = self.propose_node_storage(user) |
| 128 | + if optimal_storage is not None: |
| 129 | + return optimal_storage |
| 130 | + |
| 131 | + if self.default_storage in valid_storage_types: |
| 132 | + return self.default_storage |
| 133 | + else: |
| 134 | + return next(iter(valid_storage_types)) |
| 135 | + |
| 136 | + def propose_node_layout( |
| 137 | + self, |
| 138 | + node: torch.fx.Node, |
| 139 | + storage: VkStorageType, |
| 140 | + ) -> VkMemoryLayout: |
| 141 | + """ |
| 142 | + Performs the same steps as propose_node_storage, but detects the memory layout |
| 143 | + that should be used for the specific storage type. The same prioritization logic |
| 144 | + is applied. |
| 145 | + """ |
| 146 | + valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts |
| 147 | + # pyre-ignore |
| 148 | + if has_impl(node.target): |
| 149 | + # pyre-ignore |
| 150 | + features = get_op_features(node.target) |
| 151 | + valid_layouts = features.supported_memory_layouts(storage) |
| 152 | + layout = features.propose_memory_layout(storage) |
| 153 | + if layout is not None: |
| 154 | + return layout |
| 155 | + |
| 156 | + for arg in node.args: |
| 157 | + if isinstance(arg, torch.fx.Node) and isinstance( |
| 158 | + arg.meta["val"], FakeTensor |
| 159 | + ): |
| 160 | + layout = utils.get_node_memory_layout(arg) |
| 161 | + if layout is not None and layout in valid_layouts: |
| 162 | + return layout |
| 163 | + |
| 164 | + # If no storage type has been resolved yet, assume the optimal storage type of |
| 165 | + # the first opinionated user. This search is recursive. |
| 166 | + for user in node.users: |
| 167 | + optimal_storage = self.propose_node_layout(user, storage) |
| 168 | + if optimal_storage is not None: |
| 169 | + return optimal_storage |
| 170 | + |
| 171 | + # As a last resort, return the default storage type that should be used. |
| 172 | + if self.default_layout in valid_layouts: |
| 173 | + return self.default_layout |
| 174 | + else: |
| 175 | + return next(iter(valid_layouts)) |
| 176 | + |
| 177 | + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 178 | + sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) |
| 179 | + |
| 180 | + for node in sorted_nodes: |
| 181 | + if not isinstance(node.meta["val"], FakeTensor): |
| 182 | + continue |
| 183 | + |
| 184 | + if node.target == exir_ops.edge.et_vk.prepack.default: |
| 185 | + continue |
| 186 | + |
| 187 | + storage = self.propose_node_storage(node) |
| 188 | + layout = self.propose_node_layout(node, storage) |
| 189 | + |
| 190 | + set_memory_metadata(node, storage, layout) |
| 191 | + |
| 192 | + inserting_transitions_for_node = False |
| 193 | + for i, arg in enumerate(node.args): |
| 194 | + if not isinstance(arg, torch.fx.Node): |
| 195 | + continue |
| 196 | + if not isinstance(arg.meta["val"], FakeTensor): |
| 197 | + continue |
| 198 | + |
| 199 | + arg_storage = utils.get_node_storage_type(arg) |
| 200 | + arg_layout = utils.get_node_memory_layout(arg) |
| 201 | + |
| 202 | + if arg_storage is None: |
| 203 | + utils.set_node_spec_attr(arg, "vk_storage_type", storage) |
| 204 | + arg_storage = storage |
| 205 | + if arg_layout is None: |
| 206 | + utils.set_node_spec_attr(arg, "vk_memory_layout", layout) |
| 207 | + arg_layout = layout |
| 208 | + |
| 209 | + if arg_storage == storage and arg_layout == layout: |
| 210 | + continue |
| 211 | + |
| 212 | + if not inserting_transitions_for_node: |
| 213 | + inserting_transitions_for_node = True |
| 214 | + logger.info( |
| 215 | + f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" |
| 216 | + ) |
| 217 | + |
| 218 | + logger.info( |
| 219 | + f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" |
| 220 | + ) |
| 221 | + |
| 222 | + # Insert a clone node to copy the original tensor to a tensor with the |
| 223 | + # desired storage type and memory layout. |
| 224 | + with graph_module.graph.inserting_before(node): |
| 225 | + clone_node = graph_module.graph.create_node( |
| 226 | + "call_function", |
| 227 | + exir_ops.edge.aten.clone.default, |
| 228 | + (arg,), |
| 229 | + ) |
| 230 | + clone_node.meta["val"] = arg.meta["val"] |
| 231 | + clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) |
| 232 | + clone_node.meta["spec"].const = False |
| 233 | + set_memory_metadata(clone_node, storage, layout) |
| 234 | + arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) |
| 235 | + |
| 236 | + return PassResult(graph_module, True) |
0 commit comments