Commit 32636f5

jackzhxng authored and facebook-github-bot committed
Refactor _get_source_transforms to remove args (#10519)
Summary: Refactor `_get_source_transforms` to remove args

Reviewed By: iseeyuan

Differential Revision: D73800023

Pulled By: jackzhxng
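In essence, the call site in `_prepare_for_llama_export` stops forwarding the whole argparse namespace and passes each field as an explicit keyword argument instead. A condensed sketch of the change (parameter list abridged; the full set appears in the diff below):

# Before: model name plus the entire argparse namespace were forwarded.
_get_source_transforms(
    modelname=args.model,
    dtype_override=dtype_override,
    checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
    args=args,
)

# After: explicit keyword arguments only (abridged here).
_get_source_transforms(
    dtype_override=dtype_override,
    checkpoint=args.checkpoint,
    checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
    tokenizer_path=args.tokenizer_path,
    use_kv_cache=args.use_kv_cache,
    # ... remaining flags are passed the same way; see the diff.
)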
1 parent df8fc61 commit 32636f5

File tree

1 file changed: +135 -36 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 135 additions & 36 deletions
@@ -661,10 +661,37 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(
+                args, "use_custom_sdpa_with_attention_mask", False
+            ),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

@@ -1155,23 +1182,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
@@ -1182,21 +1255,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1279,27 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
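The added block in the hunk above repacks the explicit parameters into an ad-hoc `Args` object because `get_quant_embedding_transform` still reads attributes off an args-style namespace. A more compact way to express the same adapter, shown only as a sketch (the commit itself uses the `class Args` pattern), would be the standard library's `types.SimpleNamespace`:

from types import SimpleNamespace

# Equivalent adapter sketch: bundle the explicit parameters back into an
# attribute-style object for the legacy helper that still expects `args`.
quant_args = SimpleNamespace(
    checkpoint=checkpoint,
    tokenizer_path=tokenizer_path,
    embedding_quantize=embedding_quantize,
    use_shared_embedding=use_shared_embedding,
    use_qat=use_qat,
    use_lora=use_lora,
    preq_mode=preq_mode,
    preq_group_size=preq_group_size,
    preq_embedding_quantize=preq_embedding_quantize,
)
transforms.append(get_quant_embedding_transform(quant_args, checkpoint_dtype))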
@@ -1225,7 +1313,25 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1340,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,24 +1357,22 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
@@ -1282,29 +1383,27 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
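For reference, a minimal sketch of calling the refactored helper directly with the new keyword-only parameters. It only builds the transform list; each returned entry is a `Callable[[torch.nn.Module], torch.nn.Module]` that the export flow applies via `source_transform()`, as in the first hunk. The import paths and the `DType` member name below are assumptions, not part of this commit:

# Sketch: request only the SDPA/KV-cache transforms; every other flag keeps
# its default, so no checkpoint or args namespace is needed just to build the list.
from executorch.examples.models.llama.export_llama_lib import _get_source_transforms
from executorch.extension.llm.export.builder import DType  # assumed import path

transforms = _get_source_transforms(
    dtype_override=DType.fp32,  # assumed enum member name
    use_sdpa_with_kv_cache=True,
    use_kv_cache=True,
)
for transform in transforms:
    print(getattr(transform, "__name__", repr(transform)))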
