@@ -661,10 +661,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1155,23 +1176,65 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1182,21 +1245,21 @@ def _get_source_transforms(  # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1269,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1301,19 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1322,12 @@ def _get_source_transforms(  # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1339,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1367,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
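
For reference, a minimal sketch (not part of the commit) of how a caller could drive the refactored signature directly, now that each former `args.<flag>` is an explicit keyword argument. The checkpoint path and the `model` variable are illustrative assumptions; `DType.from_torch_dtype` and `_get_source_transforms` come from the module shown in the diff above.

# Hypothetical call site, assuming `model` is an already-constructed torch.nn.Module
# and that DType/_get_source_transforms are imported from the export module.
import torch

transforms = _get_source_transforms(
    dtype_override=DType.from_torch_dtype(torch.float32),
    checkpoint="/path/to/checkpoint.pth",  # illustrative path only
    checkpoint_dtype=DType.from_torch_dtype(torch.float32),
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)
# Apply the returned graph transforms in order.
for transform in transforms:
    model = transform(model)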