def from_args(args: argparse.Namespace) -> Self:
    """Build an ``LlmConfig`` from parsed command-line arguments.

    Copies each recognized attribute of *args* onto the matching field of a
    fresh ``LlmConfig``, grouped by sub-config (base, model, export,
    quantization, backend, debug). Attributes absent from *args* are left at
    their config defaults; string-valued choices are converted to their enum
    types (``ModelType``, ``PreqMode``, ``DtypeOverride``, ``Pt2eQuantize``,
    ``SpinQuant``, ``CoreMLQuantize``, ``CoreMLComputeUnit``).

    Args:
        args: Namespace produced by the export argument parser.

    Returns:
        The populated ``LlmConfig``.
    """
    llm_config = LlmConfig()

    # --- BaseConfig ---------------------------------------------------
    if hasattr(args, "model"):
        llm_config.base.model_class = ModelType(args.model)
    if hasattr(args, "params"):
        llm_config.base.params = args.params
    if hasattr(args, "checkpoint"):
        llm_config.base.checkpoint = args.checkpoint
    if hasattr(args, "checkpoint_dir"):
        llm_config.base.checkpoint_dir = args.checkpoint_dir
    if hasattr(args, "tokenizer_path"):
        llm_config.base.tokenizer_path = args.tokenizer_path
    if hasattr(args, "metadata"):
        llm_config.base.metadata = args.metadata
    if hasattr(args, "use_lora"):
        llm_config.base.use_lora = args.use_lora
    if hasattr(args, "fairseq2"):
        llm_config.base.fairseq2 = args.fairseq2

    # Pre-quantization (PreqMode) settings; the extra truthiness check
    # keeps a None/empty CLI value from being fed to the enum constructor.
    if hasattr(args, "preq_mode") and args.preq_mode:
        llm_config.base.preq_mode = PreqMode(args.preq_mode)
    if hasattr(args, "preq_group_size"):
        llm_config.base.preq_group_size = args.preq_group_size
    if hasattr(args, "preq_embedding_quantize"):
        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize

    # --- ModelConfig --------------------------------------------------
    if hasattr(args, "dtype_override"):
        llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
    if hasattr(args, "enable_dynamic_shape"):
        llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
    if hasattr(args, "use_shared_embedding"):
        llm_config.model.use_shared_embedding = args.use_shared_embedding
    if hasattr(args, "use_sdpa_with_kv_cache"):
        llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
    if hasattr(args, "expand_rope_table"):
        llm_config.model.expand_rope_table = args.expand_rope_table
    if hasattr(args, "use_attention_sink"):
        llm_config.model.use_attention_sink = args.use_attention_sink
    if hasattr(args, "output_prune_map"):
        llm_config.model.output_prune_map = args.output_prune_map
    if hasattr(args, "input_prune_map"):
        llm_config.model.input_prune_map = args.input_prune_map
    if hasattr(args, "use_kv_cache"):
        llm_config.model.use_kv_cache = args.use_kv_cache
    if hasattr(args, "quantize_kv_cache"):
        llm_config.model.quantize_kv_cache = args.quantize_kv_cache
    if hasattr(args, "local_global_attention"):
        llm_config.model.local_global_attention = args.local_global_attention

    # --- ExportConfig -------------------------------------------------
    if hasattr(args, "max_seq_length"):
        llm_config.export.max_seq_length = args.max_seq_length
    if hasattr(args, "max_context_length"):
        llm_config.export.max_context_length = args.max_context_length
    if hasattr(args, "output_dir"):
        llm_config.export.output_dir = args.output_dir
    if hasattr(args, "output_name"):
        llm_config.export.output_name = args.output_name
    if hasattr(args, "so_library"):
        llm_config.export.so_library = args.so_library
    if hasattr(args, "export_only"):
        llm_config.export.export_only = args.export_only

    # --- QuantizationConfig -------------------------------------------
    # NOTE: the CLI flag is "quantization_mode" but the config field is
    # the shorter "qmode".
    if hasattr(args, "quantization_mode"):
        llm_config.quantization.qmode = args.quantization_mode
    if hasattr(args, "embedding_quantize"):
        llm_config.quantization.embedding_quantize = args.embedding_quantize
    if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
        llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
    if hasattr(args, "group_size"):
        llm_config.quantization.group_size = args.group_size
    if hasattr(args, "use_spin_quant") and args.use_spin_quant:
        llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
    if hasattr(args, "use_qat"):
        llm_config.quantization.use_qat = args.use_qat
    if hasattr(args, "calibration_tasks"):
        llm_config.quantization.calibration_tasks = args.calibration_tasks
    if hasattr(args, "calibration_limit"):
        llm_config.quantization.calibration_limit = args.calibration_limit
    if hasattr(args, "calibration_seq_length"):
        llm_config.quantization.calibration_seq_length = args.calibration_seq_length
    if hasattr(args, "calibration_data"):
        llm_config.quantization.calibration_data = args.calibration_data

    # --- BackendConfig: XNNPack ---------------------------------------
    if hasattr(args, "xnnpack"):
        llm_config.backend.xnnpack.enabled = args.xnnpack
    if hasattr(args, "xnnpack_extended_ops"):
        llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops

    # --- BackendConfig: CoreML ----------------------------------------
    if hasattr(args, "coreml"):
        llm_config.backend.coreml.enabled = args.coreml
    # NOTE(review): these two use getattr defaults rather than hasattr
    # guards, so they are always assigned (False when the flag is absent)
    # — confirm this unconditional placement against the original source.
    llm_config.backend.coreml.enable_state = getattr(
        args, "coreml_enable_state", False
    )
    llm_config.backend.coreml.preserve_sdpa = getattr(
        args, "coreml_preserve_sdpa", False
    )
    if hasattr(args, "coreml_quantize") and args.coreml_quantize:
        llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
    if hasattr(args, "coreml_ios"):
        llm_config.backend.coreml.ios = args.coreml_ios
    if hasattr(args, "coreml_compute_units"):
        llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
            args.coreml_compute_units
        )

    # --- BackendConfig: Vulkan ----------------------------------------
    if hasattr(args, "vulkan"):
        llm_config.backend.vulkan.enabled = args.vulkan

    # --- BackendConfig: QNN -------------------------------------------
    if hasattr(args, "qnn"):
        llm_config.backend.qnn.enabled = args.qnn
    if hasattr(args, "use_qnn_sha"):
        llm_config.backend.qnn.use_sha = args.use_qnn_sha
    if hasattr(args, "soc_model"):
        llm_config.backend.qnn.soc_model = args.soc_model
    if hasattr(args, "optimized_rotation_path"):
        llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
    if hasattr(args, "num_sharding"):
        llm_config.backend.qnn.num_sharding = args.num_sharding

    # --- BackendConfig: MPS -------------------------------------------
    if hasattr(args, "mps"):
        llm_config.backend.mps.enabled = args.mps

    # --- DebugConfig --------------------------------------------------
    if hasattr(args, "profile_memory"):
        llm_config.debug.profile_memory = args.profile_memory
    if hasattr(args, "profile_path"):
        llm_config.debug.profile_path = args.profile_path
    if hasattr(args, "generate_etrecord"):
        llm_config.debug.generate_etrecord = args.generate_etrecord
    if hasattr(args, "generate_full_logits"):
        llm_config.debug.generate_full_logits = args.generate_full_logits
    if hasattr(args, "verbose"):
        llm_config.debug.verbose = args.verbose

    return llm_config