# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

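"""Export the LLaVA multimodal model to a single ExecuTorch program.

Three methods are exported and bundled into llava_combined_xnnpack.pte:
the image encoder, the token embedding, and the Llama-based text model,
with the text model delegated to XNNPACK with dynamic quantization.
"""
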
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
    # XnnpackFloatingPointPartitioner,
)
from executorch.examples.models.llama2.export_llama_lib import (
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama2.source_transformation.quantize import (
    get_quant_weight_transform,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from model import LlavaModel
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import Dim
from torch.nn.attention import SDPBackend


class LlavaEdgeManager(LLMEdgeManager):
    def capture_pre_autograd_graph(self) -> "LlavaEdgeManager":
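        """Trace self.model with torch.export (strict=False) and stash the
        resulting graph module for later quantization and lowering."""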
        dynamic_shape = self._get_dynamic_shape()
        # 1. Use the math SDPA backend to work around a dynamo error when tracing.
        # 2. torch.no_grad() keeps dropout (and other training-only ops) out of
        #    the traced graph.
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            self.export_program = torch.export.export(
                self.model,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                strict=False,
            )
        self.pre_autograd_graph_module = self.export_program.module()
        return self


def export_text_model(llava, embeddings, dynamic_shapes):
    class LlavaTextModel(torch.nn.Module):
        """Takes token positions and prompt/image embeddings and runs them
        through the Llama text model."""

        def __init__(self, llava):
            super().__init__()
            self.text_model = llava.text_model

        def forward(self, input_pos, embeddings):
            return self.text_model(None, input_pos, embeddings)

    llava_text_model = LlavaTextModel(llava)

    text_model_em = LLMEdgeManager(
        model=llava_text_model,
        modelname="llava_text_model",
        max_seq_len=llava.text_model_args.max_seq_len,
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
        dynamic_shapes=dynamic_shapes,
    )

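    # Reuse the export_llama CLI flags to configure quantization: XNNPACK
    # delegation (-X), 8-bit dynamic activations with 4-bit grouped weights
    # (8da4w, group size 128), and 4-bit, group-size-32 embedding quantization.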
    dtype_override = DType.fp32
    parser = build_args_parser()
    args = parser.parse_args(
        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
    )
    quant_transform = get_quant_weight_transform(args, dtype_override, False)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

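    # Apply source transforms (custom SDPA op, weight quantization), capture
    # the pre-autograd graph, then run PT2E quantization on it.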
    manager = (
        text_model_em.set_output_dir("./")
        .to_dtype(dtype_override)
        .source_transform([replace_sdpa_with_custom_op, quant_transform])
        .capture_pre_autograd_graph()
        .pt2e_quantize(quantizers)
    )

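    # Re-export the quantized graph module so to_edge() receives an
    # ExportedProgram rather than a bare nn.Module.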
    with torch.no_grad():
        text_model_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager._get_dynamic_shape(),
        )
    return text_model_ep


def export_image_encoder(llava, resized, dynamic_shapes):
    class LlavaImageEncoder(torch.nn.Module):
        """Takes images and encodes them into embeddings, which are then fed to
        LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.llava = llava

        def forward(self, images):
            return self.llava.image_embedding(images)

    llava_image_encode = LlavaImageEncoder(llava)

    # Dynamic per-channel symmetric quantization config for XNNPACK. Note that
    # this quantizer is currently unused below; see the partitioner TODO in main().
    linear_quantizer = XNNPACKQuantizer()
    operator_config_dynamic = get_symmetric_quantization_config(
        is_per_channel=True, is_dynamic=True
    )
    linear_quantizer.set_global(operator_config_dynamic)

    manager = LlavaEdgeManager(
        model=llava_image_encode,
        modelname="llava_image_encoder",
        max_seq_len=llava.text_model_args.max_seq_len,  # This may not be right
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(resized,),
        dynamic_shapes=dynamic_shapes,
    ).capture_pre_autograd_graph()

    # Export the captured module; lowering to ExecuTorch happens in main().
    with torch.no_grad():
        image_encoder_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager.dynamic_shapes,
        )
    return image_encoder_ep


def export_token_embedding(llava, prompt):
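    """Export a float32 copy of the model's token-embedding layer."""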
    embed = torch.nn.Embedding(
        llava.model_.config.vocab_size,
        llava.model_.config.hidden_size,
        llava.model_.config.pad_token_id,
    )
    embed.load_state_dict(
        llava.model_.get_model().embed_tokens.state_dict(), strict=True, assign=True
    )
    embed = embed.to(torch.float32)
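    # Let the prompt length (dimension 1) vary at runtime.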
    token_dim_1 = Dim("token_dim_1", min=2, max=3518)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
        token_embedding_ep = torch.export.export(
            embed, (prompt,), dynamic_shapes=dynamic_shapes
        )
    return token_embedding_ep


def main():
    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()

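    # Example inputs: the prompt split around the image token, plus a
    # preprocessed (resized) image tensor.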
    prompt_before_image, resized, prompt_after_image = (
        llava_model.get_inputs_for_prefill()
    )

    image_encoder_ep = export_image_encoder(
        llava, resized, llava_model._get_image_dynamic_shapes()
    )

    embeddings = llava.prefill_embedding(
        prompt_before_image, resized, prompt_after_image
    )

    text_model_ep = export_text_model(
        llava, embeddings, llava_model._get_prompt_dynamic_shapes()
    )

    token_embedding_ep = export_token_embedding(llava, prompt_before_image)

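    # Bundle the three exported programs into one multi-method edge program.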
    edge_ep = to_edge(
        {
            "image_encoder": image_encoder_ep,
            "token_embedding": token_embedding_ep,
            "text_model": text_model_ep,
        },
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    executorch_program = edge_ep.to_backend(
        {
            # TODO: Fix Xnnpack partitioner issue on image encoder.
            # "image_encoder": XnnpackFloatingPointPartitioner(),
            "text_model": XnnpackDynamicallyQuantizedPartitioner(),
        }
    ).to_executorch()

    with open("llava_combined_xnnpack.pte", "wb") as f:
        executorch_program.write_to_file(f)


if __name__ == "__main__":
    main()