 # LICENSE file in the root directory of this source tree.

 import os
-from typing import Optional
+from typing import Dict, Optional

 import torch
 import torch._inductor
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}

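For context, a minimal usage sketch of the new parameter, mirroring the call made later in `main()`. The first two arguments (`model`, the device string) are inferred from the call sites in this diff rather than shown in the signature hunk:

```python
# Sketch only: `model` is an already-built torch.nn.Module and "cuda" stands
# in for builder_args.device; both are assumptions, not part of this diff.
pt2_path = export_for_server(
    model,
    "cuda",
    output_path="model.pt2",
    dynamic_shapes=False,
    package=True,
    metadata={"tokenizer_type": "3"},  # threaded into "aot_inductor.metadata"
)
```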
@@ -106,13 +109,13 @@ def export_for_server(
     from typing import Any, Dict, Tuple, Union

     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )

     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge

     from executorch.exir.capture._config import (
@@ -170,18 +173,22 @@ def __init__(self, attention: Attention):

         self.wo = attention.wo

-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
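To illustrate the per-lane setup above, here is a self-contained sketch of the pattern: one cache module per `cache_lane`, collected in an `nn.ModuleList`. `ToyKVCache` is a hypothetical stand-in for executorch's `CustomKVCache`, and the buffer layout is illustrative only:

```python
import torch
import torch.nn as nn

class ToyKVCache(nn.Module):
    # Stand-in for CustomKVCache; the real module's cache layout may differ.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

cache_lanes = 2  # len(attention.kv_cache) in the real code
kv_cache = nn.ModuleList(
    [ToyKVCache(1, 128, 8, 64, torch.float32) for _ in range(cache_lanes)]
)
```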
@@ -219,9 +226,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)

 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

     for name, child in module.named_children():
         if isinstance(child, Attention):
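`replace_attention_with_custom_sdpa_attention` follows the usual recursive module-surgery pattern: walk `named_children()`, swap matching submodules in place with `setattr`, and recurse otherwise. A generic sketch (names here are illustrative; in the real code `target_cls` is `Attention` and the replacement wraps it in the custom SDPA attention):

```python
import torch.nn as nn

def replace_modules(module: nn.Module, target_cls, make_replacement):
    # Recursively replace every instance of target_cls in the module tree.
    for name, child in module.named_children():
        if isinstance(child, target_cls):
            setattr(module, name, make_replacement(child))
        else:
            replace_modules(child, target_cls, make_replacement)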
@@ -242,7 +247,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep
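For reference, a minimal end-to-end sketch of the `export_for_training` call that `_to_core_aten` wraps, assuming a PyTorch version that ships `torch.export.export_for_training`:

```python
import torch
from torch.export import export_for_training

class Add1(torch.nn.Module):
    def forward(self, x):
        return x + 1

# _to_core_aten requires an fx.GraphModule, so trace first.
gm = torch.fx.symbolic_trace(Add1())
core_aten_ep = export_for_training(gm, (torch.randn(2),), dynamic_shapes=None)
print(core_aten_ep.graph)
```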
@@ -354,7 +361,11 @@ def main(args):

     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )

     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -376,6 +387,7 @@ def main(args):

     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -386,9 +398,8 @@ def main(args):

     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -401,7 +412,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -439,7 +451,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -450,11 +464,23 @@ def main(args):

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
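The tokenizer-type codes written into the package metadata follow the comments in the hunk above ("2" for a llama2-style SentencePiece tokenizer, "3" for a llama3-style tokenizer, "0" when no `tokenizer_args` were resolved). A small sketch restating just that selection logic:

```python
def tokenizer_type_code(tokenizer_args) -> str:
    # Mirrors the if/elif/else in the diff above.
    if tokenizer_args is None:
        return "0"  # tokenizer could not be loaded (e.g., the gguf path)
    if tokenizer_args.is_sentencepiece:
        return "2"  # llama2
    return "3"  # llama3

metadata = {"tokenizer_type": tokenizer_type_code(None)}  # {"tokenizer_type": "0"}
```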