@@ -661,10 +661,35 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
@@ -1155,23 +1180,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1182,21 +1253,21 @@ def _get_source_transforms(  # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1277,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1309,23 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1334,12 @@ def _get_source_transforms(  # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1351,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1379,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
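
With the explicit keyword signature introduced above, _get_source_transforms can be called directly (for example from a unit test or another export path) without building an argparse namespace. A minimal sketch of such a call follows, assuming the import paths shown in the comments and the DType.fp32 enum member; neither is established by this diff, so adjust them to the actual locations in the repo.

# Sketch only: exercises the refactored keyword-based signature without argparse.
# The import paths below are assumptions, not taken from this diff.
import torch

from executorch.examples.models.llama.export_llama_lib import _get_source_transforms  # assumed path
from executorch.extension.llm.export.builder import DType  # assumed location of DType


def apply_default_transforms(model: torch.nn.Module) -> torch.nn.Module:
    # Only the options that matter are spelled out; every other setting falls
    # back to the new keyword defaults instead of being read from CLI args.
    transforms = _get_source_transforms(
        dtype_override=DType.fp32,
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
    )
    # Each returned entry is a Callable[[torch.nn.Module], torch.nn.Module].
    for transform in transforms:
        model = transform(model)
    return model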