@@ -26,92 +26,143 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
     llm_config = LlmConfig()
 
     # BaseConfig
-    llm_config.base.model_class = ModelType(args.model)
-    llm_config.base.params = args.params
-    llm_config.base.checkpoint = args.checkpoint
-    llm_config.base.checkpoint_dir = args.checkpoint_dir
-    llm_config.base.tokenizer_path = args.tokenizer_path
-    llm_config.base.metadata = args.metadata
-    llm_config.base.use_lora = bool(args.use_lora)
-    llm_config.base.fairseq2 = args.fairseq2
+    if hasattr(args, "model"):
+        llm_config.base.model_class = ModelType(args.model)
+    if hasattr(args, "params"):
+        llm_config.base.params = args.params
+    if hasattr(args, "checkpoint"):
+        llm_config.base.checkpoint = args.checkpoint
+    if hasattr(args, "checkpoint_dir"):
+        llm_config.base.checkpoint_dir = args.checkpoint_dir
+    if hasattr(args, "tokenizer_path"):
+        llm_config.base.tokenizer_path = args.tokenizer_path
+    if hasattr(args, "metadata"):
+        llm_config.base.metadata = args.metadata
+    if hasattr(args, "use_lora"):
+        llm_config.base.use_lora = args.use_lora
+    if hasattr(args, "fairseq2"):
+        llm_config.base.fairseq2 = args.fairseq2
 
     # PreqMode settings
-    if args.preq_mode:
+    if hasattr(args, "preq_mode") and args.preq_mode:
         llm_config.base.preq_mode = PreqMode(args.preq_mode)
-        llm_config.base.preq_group_size = args.preq_group_size
-        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+    if hasattr(args, "preq_group_size"):
+        llm_config.base.preq_group_size = args.preq_group_size
+    if hasattr(args, "preq_embedding_quantize"):
+        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
     # ModelConfig
-    llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
-    llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
-    llm_config.model.use_shared_embedding = args.use_shared_embedding
-    llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
-    llm_config.model.expand_rope_table = args.expand_rope_table
-    llm_config.model.use_attention_sink = args.use_attention_sink
-    llm_config.model.output_prune_map = args.output_prune_map
-    llm_config.model.input_prune_map = args.input_prune_map
-    llm_config.model.use_kv_cache = args.use_kv_cache
-    llm_config.model.quantize_kv_cache = args.quantize_kv_cache
-    llm_config.model.local_global_attention = args.local_global_attention
+    if hasattr(args, "dtype_override"):
+        llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+    if hasattr(args, "enable_dynamic_shape"):
+        llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+    if hasattr(args, "use_shared_embedding"):
+        llm_config.model.use_shared_embedding = args.use_shared_embedding
+    if hasattr(args, "use_sdpa_with_kv_cache"):
+        llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+    if hasattr(args, "expand_rope_table"):
+        llm_config.model.expand_rope_table = args.expand_rope_table
+    if hasattr(args, "use_attention_sink"):
+        llm_config.model.use_attention_sink = args.use_attention_sink
+    if hasattr(args, "output_prune_map"):
+        llm_config.model.output_prune_map = args.output_prune_map
+    if hasattr(args, "input_prune_map"):
+        llm_config.model.input_prune_map = args.input_prune_map
+    if hasattr(args, "use_kv_cache"):
+        llm_config.model.use_kv_cache = args.use_kv_cache
+    if hasattr(args, "quantize_kv_cache"):
+        llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+    if hasattr(args, "local_global_attention"):
+        llm_config.model.local_global_attention = args.local_global_attention
 
     # ExportConfig
-    llm_config.export.max_seq_length = args.max_seq_length
-    llm_config.export.max_context_length = args.max_context_length
-    llm_config.export.output_dir = args.output_dir
-    llm_config.export.output_name = args.output_name
-    llm_config.export.so_library = args.so_library
-    llm_config.export.export_only = args.export_only
+    if hasattr(args, "max_seq_length"):
+        llm_config.export.max_seq_length = args.max_seq_length
+    if hasattr(args, "max_context_length"):
+        llm_config.export.max_context_length = args.max_context_length
+    if hasattr(args, "output_dir"):
+        llm_config.export.output_dir = args.output_dir
+    if hasattr(args, "output_name"):
+        llm_config.export.output_name = args.output_name
+    if hasattr(args, "so_library"):
+        llm_config.export.so_library = args.so_library
+    if hasattr(args, "export_only"):
+        llm_config.export.export_only = args.export_only
 
     # QuantizationConfig
-    llm_config.quantization.qmode = args.quantization_mode
-    llm_config.quantization.embedding_quantize = args.embedding_quantize
-    if args.pt2e_quantize:
+    if hasattr(args, "quantization_mode"):
+        llm_config.quantization.qmode = args.quantization_mode
+    if hasattr(args, "embedding_quantize"):
+        llm_config.quantization.embedding_quantize = args.embedding_quantize
+    if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
         llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
-    llm_config.quantization.group_size = args.group_size
-    if args.use_spin_quant:
+    if hasattr(args, "group_size"):
+        llm_config.quantization.group_size = args.group_size
+    if hasattr(args, "use_spin_quant") and args.use_spin_quant:
         llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
-    llm_config.quantization.use_qat = args.use_qat
-    llm_config.quantization.calibration_tasks = args.calibration_tasks
-    llm_config.quantization.calibration_limit = args.calibration_limit
-    llm_config.quantization.calibration_seq_length = args.calibration_seq_length
-    llm_config.quantization.calibration_data = args.calibration_data
-
-    # BackendConfig
-    # XNNPack
-    llm_config.backend.xnnpack.enabled = args.xnnpack
-    llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+    if hasattr(args, "use_qat"):
+        llm_config.quantization.use_qat = args.use_qat
+    if hasattr(args, "calibration_tasks"):
+        llm_config.quantization.calibration_tasks = args.calibration_tasks
+    if hasattr(args, "calibration_limit"):
+        llm_config.quantization.calibration_limit = args.calibration_limit
+    if hasattr(args, "calibration_seq_length"):
+        llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+    if hasattr(args, "calibration_data"):
+        llm_config.quantization.calibration_data = args.calibration_data
+
+    # BackendConfig - XNNPack
+    if hasattr(args, "xnnpack"):
+        llm_config.backend.xnnpack.enabled = args.xnnpack
+    if hasattr(args, "xnnpack_extended_ops"):
+        llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
 
     # CoreML
-    llm_config.backend.coreml.enabled = args.coreml
+    if hasattr(args, "coreml"):
+        llm_config.backend.coreml.enabled = args.coreml
     llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
     llm_config.backend.coreml.preserve_sdpa = getattr(
         args, "coreml_preserve_sdpa", False
     )
-    if args.coreml_quantize:
+    if hasattr(args, "coreml_quantize") and args.coreml_quantize:
         llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
-    llm_config.backend.coreml.ios = args.coreml_ios
-    llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-        args.coreml_compute_units
-    )
+    if hasattr(args, "coreml_ios"):
+        llm_config.backend.coreml.ios = args.coreml_ios
+    if hasattr(args, "coreml_compute_units"):
+        llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+            args.coreml_compute_units
+        )
 
     # Vulkan
-    llm_config.backend.vulkan.enabled = args.vulkan
+    if hasattr(args, "vulkan"):
+        llm_config.backend.vulkan.enabled = args.vulkan
 
     # QNN
-    llm_config.backend.qnn.enabled = args.qnn
-    llm_config.backend.qnn.use_sha = args.use_qnn_sha
-    llm_config.backend.qnn.soc_model = args.soc_model
-    llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
-    llm_config.backend.qnn.num_sharding = args.num_sharding
+    if hasattr(args, "qnn"):
+        llm_config.backend.qnn.enabled = args.qnn
+    if hasattr(args, "use_qnn_sha"):
+        llm_config.backend.qnn.use_sha = args.use_qnn_sha
+    if hasattr(args, "soc_model"):
+        llm_config.backend.qnn.soc_model = args.soc_model
+    if hasattr(args, "optimized_rotation_path"):
+        llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
+    if hasattr(args, "num_sharding"):
+        llm_config.backend.qnn.num_sharding = args.num_sharding
 
     # MPS
-    llm_config.backend.mps.enabled = args.mps
+    if hasattr(args, "mps"):
+        llm_config.backend.mps.enabled = args.mps
 
     # DebugConfig
-    llm_config.debug.profile_memory = args.profile_memory
-    llm_config.debug.profile_path = args.profile_path
-    llm_config.debug.generate_etrecord = args.generate_etrecord
-    llm_config.debug.generate_full_logits = args.generate_full_logits
-    llm_config.debug.verbose = args.verbose
+    if hasattr(args, "profile_memory"):
+        llm_config.debug.profile_memory = args.profile_memory
+    if hasattr(args, "profile_path"):
+        llm_config.debug.profile_path = args.profile_path
+    if hasattr(args, "generate_etrecord"):
+        llm_config.debug.generate_etrecord = args.generate_etrecord
+    if hasattr(args, "generate_full_logits"):
+        llm_config.debug.generate_full_logits = args.generate_full_logits
+    if hasattr(args, "verbose"):
+        llm_config.debug.verbose = args.verbose
 
     return llm_config
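
A minimal sketch of the behavior this diff introduces, assuming convert_args_to_llm_config and LlmConfig are imported from the module being patched (the Namespace contents below are hypothetical, chosen to show a partially populated args):

    import argparse

    # from <patched module> import convert_args_to_llm_config  # path depends on the repo

    # Hypothetical namespace carrying only two of the many CLI flags.
    args = argparse.Namespace(use_kv_cache=True, verbose=True)

    # Before this change, the first missing field (e.g. args.model) raised
    # AttributeError; with the hasattr guards, only the fields present on
    # args are copied and everything else keeps its LlmConfig default.
    llm_config = convert_args_to_llm_config(args)
    assert llm_config.model.use_kv_cache is True
    assert llm_config.debug.verbose is True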