
Commit 41536ab

jackzhxng authored and facebook-github-bot committed
Refactor _get_source_transforms to remove args (#10519)
Summary: Refactor `_get_source_transforms` to remove args

Reviewed By: iseeyuan

Differential Revision: D73800023

Pulled By: jackzhxng
1 parent 27e159e commit 41536ab
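The shape of the change, reduced to a self-contained toy sketch (the names build_transforms_old / build_transforms_new and the two options are illustrative, not the real ExecuTorch API; the actual signature is in the diff below):

import argparse
from typing import List, Optional


# Before: the function reaches into the whole CLI namespace.
def build_transforms_old(args: argparse.Namespace) -> List[str]:
    transforms = []
    if args.use_kv_cache:
        transforms.append("kv_cache")
    if args.quantization_mode:
        transforms.append(f"quant:{args.quantization_mode}")
    return transforms


# After: every option is an explicit keyword-only parameter with a default,
# so the function can be called and tested without an argparse.Namespace.
def build_transforms_new(
    *,
    use_kv_cache: bool = False,
    quantization_mode: Optional[str] = None,
) -> List[str]:
    transforms = []
    if use_kv_cache:
        transforms.append("kv_cache")
    if quantization_mode:
        transforms.append(f"quant:{quantization_mode}")
    return transforms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_kv_cache", action="store_true")
    parser.add_argument("--quantization_mode", default=None)
    args = parser.parse_args(["--use_kv_cache"])
    # The namespace is unpacked once, at the call boundary.
    print(build_transforms_new(
        use_kv_cache=args.use_kv_cache,
        quantization_mode=args.quantization_mode,
    ))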

File tree

1 file changed: +130 -32 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 130 additions & 32 deletions
@@ -661,10 +661,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1155,23 +1181,69 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
 
     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
            it means that you want to run quantize transformations on the weights represented
            in their original dtype, while the overall dtype of the model maybe something
            different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.
 
     Returns:
         A list of transformation functions.
@@ -1182,21 +1254,21 @@ def _get_source_transforms(  # noqa
 
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1278,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1310,23 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1335,12 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
 
-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1352,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)
 
-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)
 
-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )
 
-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1380,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)
 
-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
 
-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms
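One detail worth noting: the refactored function still has to call get_quant_embedding_transform and get_quant_weight_transform, which continue to expect an args-style object, so the diff rebuilds a small attribute bag (the local `class Args`) from the explicit parameters. A minimal, self-contained sketch of that adapter pattern, assuming a hypothetical legacy helper and using types.SimpleNamespace in place of the local class:

from types import SimpleNamespace
from typing import Optional


# Hypothetical stand-in for a downstream helper (e.g. a quantization
# transform factory) that still expects an argparse-style `args` object.
def legacy_quant_helper(args) -> str:
    return f"quantize(mode={args.quantization_mode}, group_size={args.group_size})"


def apply_quant(
    *,
    quantization_mode: Optional[str] = None,
    group_size: Optional[int] = None,
) -> Optional[str]:
    if not quantization_mode:
        return None
    # Adapter: pack the explicit keyword arguments back into an attribute
    # bag, mirroring the local `class Args: pass` objects built in the diff.
    args = SimpleNamespace(
        quantization_mode=quantization_mode,
        group_size=group_size,
    )
    return legacy_quant_helper(args)


print(apply_quant(quantization_mode="int8", group_size=32))

Either form keeps the old helpers working until they are refactored to take explicit parameters as well.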
