@@ -661,10 +661,37 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(
+                args, "use_custom_sdpa_with_attention_mask", False
+            ),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1155,23 +1182,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
1171
1223
it means that you want to run quantize transformations on the weights represented
1172
1224
in their original dtype, while the overall dtype of the model maybe something
1173
1225
different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
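
For illustration, a call site built on the new keyword-only signature might look like the following sketch; the paths and settings below are placeholders, not values taken from this change:

transforms = _get_source_transforms(
    dtype_override=DType.fp32,                  # placeholder target dtype
    checkpoint="/path/to/model.pth",            # placeholder checkpoint path
    tokenizer_path="/path/to/tokenizer.model",  # placeholder tokenizer path
    quantization_mode="8da4w",                  # placeholder quantization mode
    group_size=128,
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)
# Each returned callable maps torch.nn.Module -> torch.nn.Module, so a caller
# (or LLMEdgeManager.source_transform) can apply them in order:
for transform in transforms:
    model = transform(model)
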
@@ -1182,21 +1255,21 @@ def _get_source_transforms(  # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1279,27 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
-        modelname = f"{modelname}_e"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1313,25 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1340,12 @@ def _get_source_transforms(  # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,24 +1357,22 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
@@ -1282,29 +1383,27 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
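
Both quantization branches still hand an args-like object to get_quant_embedding_transform and get_quant_weight_transform, which is why the diff builds a throwaway Args container from the new keyword parameters. The same bridge could be written with types.SimpleNamespace; a minimal sketch for the embedding branch, assuming the downstream helper only reads plain attributes:

from types import SimpleNamespace

# Bundle only the attributes the legacy quantization helper expects to find
# on `args`; the field names mirror the assignments in the diff above.
quant_args = SimpleNamespace(
    checkpoint=checkpoint,
    tokenizer_path=tokenizer_path,
    embedding_quantize=embedding_quantize,
    use_shared_embedding=use_shared_embedding,
    use_qat=use_qat,
    use_lora=use_lora,
    preq_mode=preq_mode,
    preq_group_size=preq_group_size,
    preq_embedding_quantize=preq_embedding_quantize,
)
transforms.append(get_quant_embedding_transform(quant_args, checkpoint_dtype))

Either form keeps the legacy helper interface untouched while the rest of the function moves to explicit parameters.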