
Commit 20ec41d

jackzhxng authored and facebook-github-bot committed
Refactor _get_source_transforms to remove args (#10519)
Summary: Refactor `_get_source_transforms` to remove args

Reviewed By: iseeyuan

Differential Revision: D73800023

Pulled By: jackzhxng
1 parent a868166 commit 20ec41d


examples/models/llama/export_llama_lib.py

Lines changed: 117 additions & 32 deletions
@@ -661,10 +661,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
 
@@ -1155,23 +1176,65 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
 
     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.
 
     Returns:
         A list of transformation functions.
@@ -1182,21 +1245,21 @@ def _get_source_transforms(  # noqa
 
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1269,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1301,19 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1322,12 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
 
-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1339,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)
 
-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)
 
-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )
 
-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                     transforms.append(replace_attention_to_attention_sha)
                     transforms.append(replace_causal_mask)
@@ -1282,29 +1367,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)
 
-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
 
-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms
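
For orientation, here is a minimal sketch of how a caller could invoke the refactored function with explicit keyword arguments instead of a parsed args namespace. It is not part of the commit: the import path is assumed from the file's location in the repo, and the checkpoint/tokenizer paths and chosen flags are hypothetical placeholders.

import torch

# Module path assumed from examples/models/llama/export_llama_lib.py; adjust to
# however the file is importable in your checkout.
from executorch.examples.models.llama.export_llama_lib import (
    DType,
    _get_source_transforms,
)

# Options are now explicit keyword-only parameters; anything omitted falls back
# to the defaults declared in the new signature, so no argparse namespace is needed.
transforms = _get_source_transforms(
    dtype_override=DType.from_torch_dtype(torch.float32),
    checkpoint="/tmp/llama/consolidated.00.pth",  # placeholder path
    tokenizer_path="/tmp/llama/tokenizer.model",  # placeholder path
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
)

# Each entry is a Callable[[torch.nn.Module], torch.nn.Module]; in the export
# flow the list is handed to edge_manager.source_transform(...), as shown in
# the first hunk above.
print(f"collected {len(transforms)} source transforms")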
