@@ -651,10 +651,30 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
 
@@ -1145,23 +1165,63 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
 
     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.
 
     Returns:
         A list of transformation functions.
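
For orientation, here is a hypothetical call using the refactored keyword-only signature. It is a sketch only: the checkpoint path and option values are illustrative rather than taken from the PR, and every omitted parameter keeps the default shown in the signature above.

    transforms = _get_source_transforms(
        dtype_override=DType.fp32,
        checkpoint="/path/to/checkpoint.pth",  # hypothetical path
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
        quantization_mode="8da4w",  # example mode string
    )
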
@@ -1172,21 +1232,21 @@ def _get_source_transforms(  # noqa
 
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1196,12 +1256,24 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1215,7 +1287,18 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1224,15 +1307,12 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
 
-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1244,23 +1324,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)
 
-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)
 
-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )
 
-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1272,29 +1352,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)
 
-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
 
-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms
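
A side note on the mock `Args` shim the diff builds before calling get_quant_embedding_transform and get_quant_weight_transform: the same attribute bundle can be produced in one expression with types.SimpleNamespace, since the bare Args class in the diff suggests the helpers only read attributes. The sketch below is illustrative only; the helper name _make_quant_embedding_args and the example values are hypothetical and not part of the PR.

    from types import SimpleNamespace
    from typing import Optional


    def _make_quant_embedding_args(
        checkpoint: Optional[str] = None,
        embedding_quantize: Optional[str] = None,
        use_shared_embedding: bool = False,
        use_qat: bool = False,
        use_lora: int = 0,
        preq_mode: Optional[str] = None,
        preq_group_size: int = 32,
        preq_embedding_quantize: str = "8,0",
    ) -> SimpleNamespace:
        # Bundle the keyword arguments back into an attribute-style object,
        # mirroring the ad-hoc Args class the diff defines before
        # get_quant_embedding_transform(args, checkpoint_dtype) is called.
        return SimpleNamespace(
            checkpoint=checkpoint,
            embedding_quantize=embedding_quantize,
            use_shared_embedding=use_shared_embedding,
            use_qat=use_qat,
            use_lora=use_lora,
            preq_mode=preq_mode,
            preq_group_size=preq_group_size,
            preq_embedding_quantize=preq_embedding_quantize,
        )


    # The resulting object exposes the same attributes as the mock Args instance.
    quant_args = _make_quant_embedding_args(embedding_quantize="8,1024")
    print(quant_args.embedding_quantize, quant_args.preq_group_size)  # 8,1024 32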