24
24
from executorch .examples .models .llama2 .source_transformation .sdpa import (
25
25
replace_sdpa_with_custom_op ,
26
26
)
27
+ from executorch .examples .models .llava .model import LlavaModel
27
28
from executorch .exir import EdgeCompileConfig
28
29
from executorch .exir .program ._program import _to_edge_transform_and_lower
29
30
30
31
from executorch .extension .llm .export .builder import DType , LLMEdgeManager
31
- from model import LlavaModel
32
32
from torch .ao .quantization .quantizer .xnnpack_quantizer import (
33
33
get_symmetric_quantization_config ,
34
34
XNNPACKQuantizer ,
@@ -85,7 +85,7 @@ def forward(self, input_pos, embeddings):
85
85
["-X" , "-qmode" , "8da4w" , "--group_size" , "128" , "--embedding-quantize" , "4,32" ]
86
86
)
87
87
quant_transform = get_quant_weight_transform (args , dtype_override , False )
88
- pt2e_quant_params , quantizers , quant_dtype = get_quantizer_and_quant_params (args )
88
+ _ , quantizers , _ = get_quantizer_and_quant_params (args )
89
89
source_transforms = []
90
90
if llava .use_sdpa_with_kv_cache_op :
91
91
source_transforms .append (replace_sdpa_with_custom_op )
@@ -149,15 +149,7 @@ def forward(self, images):
149
149
150
150
151
151
def export_token_embedding (llava , prompt ):
152
- embed = torch .nn .Embedding (
153
- llava .model_ .config .vocab_size ,
154
- llava .model_ .config .hidden_size ,
155
- llava .model_ .config .pad_token_id ,
156
- )
157
- embed .load_state_dict (
158
- llava .model_ .get_model ().embed_tokens .state_dict (), strict = True , assign = True
159
- )
160
- embed = embed .to (torch .float32 )
152
+ embed = llava .embed_tokens
161
153
token_dim_1 = Dim ("token_dim_1" , min = 2 , max = 3518 )
162
154
dynamic_shapes = [{1 : token_dim_1 }]
163
155
with torch .no_grad ():
@@ -167,24 +159,7 @@ def export_token_embedding(llava, prompt):
167
159
return token_embedding_ep
168
160
169
161
170
- def main ():
171
- parser = ArgumentParser ()
172
- parser .add_argument (
173
- "--use-sdpa-with-kv-cache" ,
174
- default = True ,
175
- action = BooleanOptionalAction ,
176
- help = "Use sdpa_with_kv_cache custom op in LLava text model." ,
177
- )
178
- parser .add_argument (
179
- "--pte-name" ,
180
- default = "llava_combined_xnnpack.pte" ,
181
- help = "Name of the exported ExecuTorch program." ,
182
- )
183
- args = parser .parse_args ()
184
- logging .info (
185
- f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: { args .use_sdpa_with_kv_cache } "
186
- )
187
- llava_model = LlavaModel (use_sdpa_with_kv_cache_op = args .use_sdpa_with_kv_cache )
162
+ def export_all (llava_model : LlavaModel ):
188
163
llava = llava_model .get_eager_model ()
189
164
190
165
(
@@ -226,6 +201,29 @@ def main():
226
201
)
227
202
228
203
executorch_program = lowered_and_edge .to_executorch ()
204
+ return executorch_program
205
+
206
+
207
+ def main ():
208
+ parser = ArgumentParser ()
209
+ parser .add_argument (
210
+ "--use-sdpa-with-kv-cache" ,
211
+ default = True ,
212
+ action = BooleanOptionalAction ,
213
+ help = "Use sdpa_with_kv_cache custom op in LLava text model." ,
214
+ )
215
+ parser .add_argument (
216
+ "--pte-name" ,
217
+ default = "llava_combined_xnnpack.pte" ,
218
+ help = "Name of the exported ExecuTorch program." ,
219
+ )
220
+ args = parser .parse_args ()
221
+ logging .info (
222
+ f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: { args .use_sdpa_with_kv_cache } "
223
+ )
224
+ llava_model = LlavaModel (use_sdpa_with_kv_cache_op = args .use_sdpa_with_kv_cache )
225
+
226
+ executorch_program = export_all (llava_model )
229
227
230
228
with open (args .pte_name , "wb" ) as f :
231
229
executorch_program .write_to_file (f )
0 commit comments