|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Optional |
| 8 | + |
| 9 | +import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema |
| 10 | + |
| 11 | +import torch |
| 12 | + |
| 13 | +from executorch.exir.tensor import TensorSpec |
| 14 | +from torch._export.utils import get_buffer, get_param, is_buffer, is_param |
| 15 | +from torch.export import ExportedProgram |
| 16 | +from torch.fx import Node |
| 17 | + |
| 18 | + |
| 19 | +class VkGraphBuilder: |
| 20 | + def __init__(self, program: ExportedProgram) -> None: |
| 21 | + self.program = program |
| 22 | + |
| 23 | + self.chain = [] |
| 24 | + self.values = [] |
| 25 | + self.input_ids = [] |
| 26 | + self.output_ids = [] |
| 27 | + self.const_tensors = [] |
| 28 | + |
| 29 | + # Mapping from torch.fx.Node to VkValue id |
| 30 | + self.node_to_value_ids = {} |
| 31 | + |
| 32 | + @staticmethod |
| 33 | + def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: |
| 34 | + if torch_dtype == torch.float32: |
| 35 | + return vk_graph_schema.VkDataType.fp32 |
| 36 | + else: |
| 37 | + raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") |
| 38 | + |
| 39 | + def is_constant(self, node: torch.fx.Node): |
| 40 | + return ( |
| 41 | + node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants |
| 42 | + ) |
| 43 | + |
| 44 | + def is_get_attr_node(self, node: torch.fx.Node) -> bool: |
| 45 | + """ |
| 46 | + Returns true if the given node is a get attr node for a tensor of the model |
| 47 | + """ |
| 48 | + return isinstance(node, torch.fx.Node) and node.op == "get_attr" |
| 49 | + |
| 50 | + def is_param_node(self, node: torch.fx.Node) -> bool: |
| 51 | + """ |
| 52 | + Check if the given node is a parameter within the exported program |
| 53 | + """ |
| 54 | + return ( |
| 55 | + self.is_get_attr_node(node) |
| 56 | + or is_param(self.program, node) |
| 57 | + or is_buffer(self.program, node) |
| 58 | + or self.is_constant(node) |
| 59 | + ) |
| 60 | + |
| 61 | + def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]: |
| 62 | + """ |
| 63 | + Returns the constant associated with the given node in the exported program. |
| 64 | + Returns None if the node is not a constant within the exported program |
| 65 | + """ |
| 66 | + if self.is_constant(node): |
| 67 | + constant_name = ( |
| 68 | + self.program.graph_signature.inputs_to_lifted_tensor_constants[ |
| 69 | + node.name |
| 70 | + ] |
| 71 | + ) |
| 72 | + if constant_name in self.program.constants: |
| 73 | + return self.program.constants[constant_name] |
| 74 | + else: |
| 75 | + return None |
| 76 | + |
| 77 | + return None |
| 78 | + |
| 79 | + def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor: |
| 80 | + tensor = None |
| 81 | + if node is None: |
| 82 | + raise RuntimeError("node is None") |
| 83 | + elif is_param(self.program, node): |
| 84 | + tensor = get_param(self.program, node) |
| 85 | + elif is_buffer(self.program, node): |
| 86 | + tensor = get_buffer(self.program, node) |
| 87 | + elif self.is_constant(node): |
| 88 | + tensor = self.get_constant(node) |
| 89 | + elif self.is_get_attr_node(node): |
| 90 | + # This is a hack to support both lifted and unlifted graph |
| 91 | + try: |
| 92 | + tensor = getattr(node.graph.owning_module, node.target) |
| 93 | + except AttributeError: |
| 94 | + tensor = getattr(self.program.graph_module, node.target) |
| 95 | + else: |
| 96 | + raise RuntimeError(f"unsupported param type, {node.op}.") |
| 97 | + |
| 98 | + assert tensor is not None |
| 99 | + return tensor |
| 100 | + |
| 101 | + def maybe_add_constant_tensor(self, node: Node) -> int: |
| 102 | + const_buffer_idx = -1 |
| 103 | + if self.is_param_node(node): |
| 104 | + const_buffer_idx = len(self.const_tensors) |
| 105 | + self.const_tensors.append(self.get_param_tensor(node)) |
| 106 | + |
| 107 | + return const_buffer_idx |
| 108 | + |
| 109 | + def create_single_vk_value(self, node: Node) -> int: |
| 110 | + constant_id = self.maybe_add_constant_tensor(node) |
| 111 | + |
| 112 | + spec = node.meta.get("spec") |
| 113 | + assert isinstance(spec, TensorSpec) |
| 114 | + new_id = len(self.values) |
| 115 | + if node not in self.node_to_value_ids: |
| 116 | + self.node_to_value_ids[node] = new_id |
| 117 | + else: |
| 118 | + current_ids = self.node_to_value_ids[node] |
| 119 | + if isinstance(current_ids, int): |
| 120 | + current_ids = [current_ids, new_id] |
| 121 | + else: |
| 122 | + current_ids.append(new_id) |
| 123 | + |
| 124 | + # Negative id indicates that this tensor will have its own dedicated memory. |
| 125 | + mem_obj_id = -1 |
| 126 | + if spec.mem_obj_id is not None: |
| 127 | + mem_obj_id = spec.mem_obj_id |
| 128 | + |
| 129 | + self.values.append( |
| 130 | + vk_graph_schema.VkValue( |
| 131 | + value=vk_graph_schema.VkTensor( |
| 132 | + datatype=self.get_vk_datatype(spec.dtype), |
| 133 | + dims=spec.shape, |
| 134 | + constant_id=constant_id, |
| 135 | + mem_obj_id=mem_obj_id, |
| 136 | + ) |
| 137 | + ) |
| 138 | + ) |
| 139 | + return new_id |
| 140 | + |
| 141 | + def create_vk_values_for(self, node: Node): |
| 142 | + spec = node.meta.get("spec") |
| 143 | + if isinstance(spec, TensorSpec): |
| 144 | + return self.create_single_vk_value(node) |
| 145 | + else: |
| 146 | + raise RuntimeError( |
| 147 | + "Creating values for nodes with collection types is not supported yet." |
| 148 | + ) |
| 149 | + |
| 150 | + def process_placeholder_node(self, node: Node) -> None: |
| 151 | + ids = self.create_vk_values_for(node) |
| 152 | + if not self.is_param_node(node): |
| 153 | + if isinstance(ids, int): |
| 154 | + self.input_ids.append(ids) |
| 155 | + else: |
| 156 | + self.input_ids += ids |
| 157 | + |
| 158 | + def process_call_function_node(self, node) -> None: |
| 159 | + args = [] |
| 160 | + # Add input nodes |
| 161 | + for inp_node in node.all_input_nodes: |
| 162 | + if inp_node not in self.node_to_value_ids: |
| 163 | + raise AssertionError( |
| 164 | + "Cannot find input to current node in node_to_value_ids. This means " |
| 165 | + "this node is being serialized before its input which is not allowed." |
| 166 | + ) |
| 167 | + args.append(self.node_to_value_ids[inp_node]) |
| 168 | + # Add output node |
| 169 | + args.append(self.create_vk_values_for(node)) |
| 170 | + |
| 171 | + self.chain.append( |
| 172 | + vk_graph_schema.OperatorCall( |
| 173 | + name=node.target.__name__, |
| 174 | + args=args, |
| 175 | + ), |
| 176 | + ) |
| 177 | + |
| 178 | + def process_getattr_node(self, node: Node) -> None: |
| 179 | + self.create_vk_values_for(node) |
| 180 | + |
| 181 | + def process_output_node(self, node: Node) -> None: |
| 182 | + if node.all_input_nodes[0] not in self.node_to_value_ids: |
| 183 | + raise AssertionError( |
| 184 | + "Cannot find input to output node in node_to_value_ids. This means the " |
| 185 | + "output node is being serialized before its corresponding internal node " |
| 186 | + "which is not allowed." |
| 187 | + ) |
| 188 | + self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]]) |
| 189 | + |
| 190 | + def process_node(self, node: Node) -> None: |
| 191 | + if node.op == "placeholder": |
| 192 | + self.process_placeholder_node(node) |
| 193 | + elif node.op == "call_function": |
| 194 | + self.process_call_function_node(node) |
| 195 | + elif node.op == "get_attr": |
| 196 | + self.process_getattr_node(node) |
| 197 | + elif node.op == "output": |
| 198 | + self.process_output_node(node) |
| 199 | + else: |
| 200 | + raise AssertionError(f"Unsupported node op: {node.op}") |
| 201 | + |
| 202 | + def build_graph(self) -> vk_graph_schema.VkGraph: |
| 203 | + for node in self.program.graph_module.graph.nodes: |
| 204 | + self.process_node(node) |
| 205 | + |
| 206 | + return vk_graph_schema.VkGraph( |
| 207 | + version="0", |
| 208 | + chain=self.chain, |
| 209 | + values=self.values, |
| 210 | + input_ids=self.input_ids, |
| 211 | + output_ids=self.output_ids, |
| 212 | + constants=[], |
| 213 | + shaders=[], |
| 214 | + ) |