
Commit ff5e6d9

refactor _get_source_transforms to remove args parameter and unused modelname
Parent: 2837867
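With this change, _get_source_transforms takes each export option as an explicit keyword-only parameter instead of the whole argparse namespace, and the unused modelname parameter is removed. A minimal sketch of calling the refactored function directly is shown below; the import paths, the placeholder checkpoint path, and the stand-in module are assumptions for illustration, not part of this commit.

    # Minimal sketch: calling the refactored _get_source_transforms directly.
    # The import paths below are assumptions based on the repo layout and may
    # differ depending on how ExecuTorch is installed.
    import torch

    from executorch.examples.models.llama.export_llama_lib import (
        DType,
        _get_source_transforms,
    )

    transforms = _get_source_transforms(
        dtype_override=DType.from_torch_dtype(torch.float32),
        checkpoint="/path/to/checkpoint.pth",  # hypothetical path, unused unless quantizing
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
    )

    # Each transform is a Callable[[torch.nn.Module], torch.nn.Module]; apply in order.
    model = torch.nn.Identity()  # stand-in for the real Llama module
    for transform in transforms:
        model = transform(model)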


examples/models/llama/export_llama_lib.py

Lines changed: 112 additions & 32 deletions
@@ -651,10 +651,30 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1145,23 +1165,63 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1172,21 +1232,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1196,12 +1256,24 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1215,7 +1287,18 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1224,15 +1307,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1244,23 +1324,23 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1272,29 +1352,29 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
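Within the embedding_quantize and quantization_mode branches, the new keyword parameters are packed back into a small throwaway args-style object, because get_quant_embedding_transform and get_quant_weight_transform still expect an object with argparse-like attributes. Below is a rough sketch of the same shim expressed with the standard library's argparse.Namespace, shown only to illustrate the pattern; the commit itself defines a local Args class and assigns attributes one by one.

    # Sketch only: the "mock args" shim expressed with argparse.Namespace.
    # Behavior matches the commit's local Args class as long as the downstream
    # helpers only read these attribute names.
    import argparse


    def build_quant_args(
        checkpoint=None,
        quantization_mode=None,
        preq_group_size=32,
        use_shared_embedding=False,
        use_qat=False,
        use_lora=0,
        preq_mode=None,
    ):
        return argparse.Namespace(
            checkpoint=checkpoint,
            quantization_mode=quantization_mode,
            group_size=preq_group_size,  # the commit reuses preq_group_size here
            use_shared_embedding=use_shared_embedding,
            use_qat=use_qat,
            use_lora=use_lora,
            preq_mode=preq_mode,
        )


    # Example value; any mode string accepted by the quantization helpers works.
    args = build_quant_args(quantization_mode="8da4w")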
