
Commit af4edc6

jackzhxng authored and facebook-github-bot committed
Refactor _get_source_transforms to remove args (#10519)
Summary: Refactor `_get_source_transforms` to remove args

Reviewed By: iseeyuan

Differential Revision: D73800023

Pulled By: jackzhxng
1 parent 1da5168 commit af4edc6
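With this refactor, callers no longer thread a whole argparse namespace into `_get_source_transforms`; each option becomes an explicit keyword argument with a default. A minimal sketch of a direct call under the new signature; the import path is assumed from the file touched by this commit, and the parameter values are purely illustrative:

    # Sketch only: the import path is an assumption based on the file shown in this diff.
    from executorch.examples.models.llama.export_llama_lib import (
        DType,
        _get_source_transforms,
    )

    # Ask only for the behaviour we need; every other option keeps its default.
    transforms = _get_source_transforms(
        dtype_override=DType.fp32,
        checkpoint="/path/to/model.pth",  # illustrative path
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
    )

    # The result is a list of Callable[[torch.nn.Module], torch.nn.Module];
    # apply the transforms in order to rewrite the eager model:
    #     for transform in transforms:
    #         model = transform(model)

The first hunk below updates `_prepare_for_llama_export` accordingly, mapping each CLI flag onto these keywords one-to-one.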

File tree

1 file changed: +129 −32


examples/models/llama/export_llama_lib.py

Lines changed: 129 additions & 32 deletions
@@ -661,10 +661,35 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1155,23 +1180,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1182,21 +1253,21 @@ def _get_source_transforms(  # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1277,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

         # quantization_mode should be applied after embedding_quantize
         # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1309,23 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1334,12 @@ def _get_source_transforms(  # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1351,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1379,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
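The two quantization branches above still hand an attribute-style object to `get_quant_embedding_transform` and `get_quant_weight_transform`, which is why the diff builds throwaway `Args` instances field by field. The same bridge can be written with the standard library's `types.SimpleNamespace`; the helper below is only an illustration of that pattern under the names used in this diff, not code from the commit:

    import types

    def _quant_embedding_args(
        checkpoint=None,
        tokenizer_path=None,
        embedding_quantize=None,
        use_shared_embedding=False,
        use_qat=False,
        use_lora=0,
        preq_mode=None,
        preq_group_size=None,
        preq_embedding_quantize=None,
    ):
        # Bundle the keyword arguments into an attribute-style object, equivalent
        # to the ad-hoc `class Args: pass` instances built in the diff above.
        return types.SimpleNamespace(
            checkpoint=checkpoint,
            tokenizer_path=tokenizer_path,
            embedding_quantize=embedding_quantize,
            use_shared_embedding=use_shared_embedding,
            use_qat=use_qat,
            use_lora=use_lora,
            preq_mode=preq_mode,
            preq_group_size=preq_group_size,
            preq_embedding_quantize=preq_embedding_quantize,
        )

    # Usage mirroring the embedding_quantize branch:
    #     args = _quant_embedding_args(checkpoint=checkpoint, embedding_quantize=embedding_quantize)
    #     transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))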

0 commit comments
