@@ -661,10 +661,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
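One detail worth noting in the new call site: every option is read straight off `args` except `use_custom_sdpa_with_attention_mask`, which goes through `getattr` with a `False` default, presumably because not every args-like object defines that flag (the old `isinstance`/`getattr` guard removed further down had the same effect). A small self-contained sketch of that fallback behaviour; the Namespace contents here are illustrative, not from the PR:

import argparse

# Namespace without the flag: getattr must fall back to False instead of
# raising AttributeError.
args = argparse.Namespace(use_sdpa_with_kv_cache=True)
assert getattr(args, "use_custom_sdpa_with_attention_mask", False) is False

# Namespace with the flag set: the real value is passed through unchanged.
args = argparse.Namespace(use_custom_sdpa_with_attention_mask=True)
assert getattr(args, "use_custom_sdpa_with_attention_mask", False) is True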
@@ -1155,23 +1181,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model may be something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
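The return annotation above is the whole contract: each element of the returned list maps a torch.nn.Module to a torch.nn.Module. In the PR the list is handed to LLMEdgeManager.source_transform rather than applied by hand, but a short sketch of folding such a list over a model makes the shape concrete (apply_source_transforms is a hypothetical helper, not part of the diff):

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: list,  # List[Callable[[torch.nn.Module], torch.nn.Module]]
) -> torch.nn.Module:
    # Fold the transforms over the model, left to right; each one returns
    # the (possibly rewritten) module that the next transform receives.
    for transform in transforms:
        model = transform(model)
    return model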
@@ -1182,21 +1254,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1278,25 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1310,23 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1335,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1352,23 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1380,29 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
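A side note on the two "mock args" blocks above: they exist only because get_quant_embedding_transform and get_quant_weight_transform still expect an args-like object. If that adapter stays, one possible simplification (not what the PR does) is types.SimpleNamespace from the standard library, which builds the same attribute bag in a single expression; make_embedding_quant_args below is a hypothetical helper name mirroring the embedding_quantize block:

from types import SimpleNamespace
from typing import Optional


def make_embedding_quant_args(
    checkpoint: Optional[str],
    tokenizer_path: Optional[str],
    embedding_quantize: Optional[str],
    use_shared_embedding: bool,
    use_qat: bool,
    use_lora: int,
    preq_mode: Optional[str],
    preq_group_size: Optional[int],
    preq_embedding_quantize: Optional[str],
) -> SimpleNamespace:
    # Same attribute bag the diff assembles with `class Args: pass`,
    # built in one expression instead of one assignment per field.
    return SimpleNamespace(
        checkpoint=checkpoint,
        tokenizer_path=tokenizer_path,
        embedding_quantize=embedding_quantize,
        use_shared_embedding=use_shared_embedding,
        use_qat=use_qat,
        use_lora=use_lora,
        preq_mode=preq_mode,
        preq_group_size=preq_group_size,
        preq_embedding_quantize=preq_embedding_quantize,
    )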