
Commit 3c1ece4

Convert args to LlmConfig
Pull Request resolved: #11081
@imported-using-ghimport
Differential Revision: [D75263990](https://our.internmc.facebook.com/intern/diff/D75263990/)
ghstack-source-id: 288807823
1 parent 2996b26 · commit 3c1ece4


examples/models/llama/config/llm_config.py

Lines changed: 139 additions & 1 deletion
@@ -477,7 +477,145 @@ def from_args(args: argparse.Namespace) -> Self:
         """
         llm_config = LlmConfig()
 
-        # TODO: conversion code.
+        # BaseConfig
+        if hasattr(args, "model"):
+            llm_config.base.model_class = ModelType(args.model)
+        if hasattr(args, "params"):
+            llm_config.base.params = args.params
+        if hasattr(args, "checkpoint"):
+            llm_config.base.checkpoint = args.checkpoint
+        if hasattr(args, "checkpoint_dir"):
+            llm_config.base.checkpoint_dir = args.checkpoint_dir
+        if hasattr(args, "tokenizer_path"):
+            llm_config.base.tokenizer_path = args.tokenizer_path
+        if hasattr(args, "metadata"):
+            llm_config.base.metadata = args.metadata
+        if hasattr(args, "use_lora"):
+            llm_config.base.use_lora = args.use_lora
+        if hasattr(args, "fairseq2"):
+            llm_config.base.fairseq2 = args.fairseq2
+
+        # PreqMode settings
+        if hasattr(args, "preq_mode") and args.preq_mode:
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+        if hasattr(args, "preq_group_size"):
+            llm_config.base.preq_group_size = args.preq_group_size
+        if hasattr(args, "preq_embedding_quantize"):
+            llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+
+        # ModelConfig
+        if hasattr(args, "dtype_override"):
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+        if hasattr(args, "enable_dynamic_shape"):
+            llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+        if hasattr(args, "use_shared_embedding"):
+            llm_config.model.use_shared_embedding = args.use_shared_embedding
+        if hasattr(args, "use_sdpa_with_kv_cache"):
+            llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+        if hasattr(args, "expand_rope_table"):
+            llm_config.model.expand_rope_table = args.expand_rope_table
+        if hasattr(args, "use_attention_sink"):
+            llm_config.model.use_attention_sink = args.use_attention_sink
+        if hasattr(args, "output_prune_map"):
+            llm_config.model.output_prune_map = args.output_prune_map
+        if hasattr(args, "input_prune_map"):
+            llm_config.model.input_prune_map = args.input_prune_map
+        if hasattr(args, "use_kv_cache"):
+            llm_config.model.use_kv_cache = args.use_kv_cache
+        if hasattr(args, "quantize_kv_cache"):
+            llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+        if hasattr(args, "local_global_attention"):
+            llm_config.model.local_global_attention = args.local_global_attention
+
+        # ExportConfig
+        if hasattr(args, "max_seq_length"):
+            llm_config.export.max_seq_length = args.max_seq_length
+        if hasattr(args, "max_context_length"):
+            llm_config.export.max_context_length = args.max_context_length
+        if hasattr(args, "output_dir"):
+            llm_config.export.output_dir = args.output_dir
+        if hasattr(args, "output_name"):
+            llm_config.export.output_name = args.output_name
+        if hasattr(args, "so_library"):
+            llm_config.export.so_library = args.so_library
+        if hasattr(args, "export_only"):
+            llm_config.export.export_only = args.export_only
+
+        # QuantizationConfig
+        if hasattr(args, "quantization_mode"):
+            llm_config.quantization.qmode = args.quantization_mode
+        if hasattr(args, "embedding_quantize"):
+            llm_config.quantization.embedding_quantize = args.embedding_quantize
+        if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+        if hasattr(args, "group_size"):
+            llm_config.quantization.group_size = args.group_size
+        if hasattr(args, "use_spin_quant") and args.use_spin_quant:
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+        if hasattr(args, "use_qat"):
+            llm_config.quantization.use_qat = args.use_qat
+        if hasattr(args, "calibration_tasks"):
+            llm_config.quantization.calibration_tasks = args.calibration_tasks
+        if hasattr(args, "calibration_limit"):
+            llm_config.quantization.calibration_limit = args.calibration_limit
+        if hasattr(args, "calibration_seq_length"):
+            llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+        if hasattr(args, "calibration_data"):
+            llm_config.quantization.calibration_data = args.calibration_data
+
+        # BackendConfig - XNNPack
+        if hasattr(args, "xnnpack"):
+            llm_config.backend.xnnpack.enabled = args.xnnpack
+        if hasattr(args, "xnnpack_extended_ops"):
+            llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+
+        # CoreML
+        if hasattr(args, "coreml"):
+            llm_config.backend.coreml.enabled = args.coreml
+        llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
+        llm_config.backend.coreml.preserve_sdpa = getattr(
+            args, "coreml_preserve_sdpa", False
+        )
+        if hasattr(args, "coreml_quantize") and args.coreml_quantize:
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+        if hasattr(args, "coreml_ios"):
+            llm_config.backend.coreml.ios = args.coreml_ios
+        if hasattr(args, "coreml_compute_units"):
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
+
+        # Vulkan
+        if hasattr(args, "vulkan"):
+            llm_config.backend.vulkan.enabled = args.vulkan
+
+        # QNN
+        if hasattr(args, "qnn"):
+            llm_config.backend.qnn.enabled = args.qnn
+        if hasattr(args, "use_qnn_sha"):
+            llm_config.backend.qnn.use_sha = args.use_qnn_sha
+        if hasattr(args, "soc_model"):
+            llm_config.backend.qnn.soc_model = args.soc_model
+        if hasattr(args, "optimized_rotation_path"):
+            llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
+        if hasattr(args, "num_sharding"):
+            llm_config.backend.qnn.num_sharding = args.num_sharding
+
+        # MPS
+        if hasattr(args, "mps"):
+            llm_config.backend.mps.enabled = args.mps
+
+        # DebugConfig
+        if hasattr(args, "profile_memory"):
+            llm_config.debug.profile_memory = args.profile_memory
+        if hasattr(args, "profile_path"):
+            llm_config.debug.profile_path = args.profile_path
+        if hasattr(args, "generate_etrecord"):
+            llm_config.debug.generate_etrecord = args.generate_etrecord
+        if hasattr(args, "generate_full_logits"):
+            llm_config.debug.generate_full_logits = args.generate_full_logits
+        if hasattr(args, "verbose"):
+            llm_config.debug.verbose = args.verbose
 
         return llm_config

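The conversion is tolerant of sparse namespaces: each field is copied only when the corresponding attribute exists on `args` (with enum-typed options wrapped in their enum constructors), so callers can pass a partially populated `argparse.Namespace`. Below is a minimal sketch of how this might be exercised; the parser flags, the assumption that `from_args` is exposed as a static method on `LlmConfig`, and the import path are illustrative rather than taken from this diff.

```python
import argparse

# Assumed import path, mirroring the file touched in this commit.
from examples.models.llama.config.llm_config import LlmConfig

# Hypothetical parser covering a small subset of the flags from_args() understands.
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default="model.pth")
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--use_kv_cache", action="store_true")
parser.add_argument("--xnnpack", action="store_true")

args = parser.parse_args(["--use_kv_cache", "--xnnpack"])

# Attributes present on `args` are copied into the nested config sections;
# anything absent keeps its LlmConfig default.
llm_config = LlmConfig.from_args(args)

assert llm_config.model.use_kv_cache
assert llm_config.backend.xnnpack.enabled
assert llm_config.export.max_seq_length == 128
```

Because every copy is guarded by `hasattr` (or `getattr` with a default), the same routine can serve both the full export CLI and smaller test harnesses that only define a handful of flags.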