@@ -47,15 +47,30 @@ def __init__(self, **kwargs):
        # Params file.
        params_path = kwargs.get("params", None)

-        self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", False)
-        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        self.max_seq_len = kwargs.get("max_seq_len", 128)
-        self.max_context_len = kwargs.get("max_context_len", 128)
        self.llm_config = kwargs.get("llm_config", None)
+
+        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
+        if self.llm_config:
+            self.use_kv_cache = self.llm_config.model.use_kv_cache
+            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+            self.generate_full_logits = self.llm_config.debug.generate_full_logits
+            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+            self.input_prune_map_path = self.llm_config.model.input_prune_map
+            self.output_prune_map_path = self.llm_config.model.output_prune_map
+            self.max_seq_len = self.llm_config.export.max_seq_length
+            self.max_context_len = self.llm_config.export.max_context_length
+            self.verbose = self.llm_config.debug.verbose
+        else:
+            # Fallback to kwargs for backward compatibility
+            self.use_kv_cache = kwargs.get("use_kv_cache", False)
+            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
+            self.generate_full_logits = kwargs.get("generate_full_logits", False)
+            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
+            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+            self.max_seq_len = kwargs.get("max_seq_len", 128)
+            self.max_context_len = kwargs.get("max_context_len", 128)
+            self.verbose = kwargs.get("verbose", False)

        assert (
            self.max_context_len >= self.max_seq_len
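
As a reading aid for the hunk above: llm_config now takes full precedence, and the legacy kwargs are consulted only when no llm_config is passed at all. Below is a minimal, hypothetical sketch (not code from this PR; resolve_max_seq_len and the SimpleNamespace stand-ins are assumptions that only mirror the attribute paths shown in the diff) of that precedence for a single field:

```python
from types import SimpleNamespace


def resolve_max_seq_len(llm_config=None, **kwargs):
    """Hypothetical helper mirroring the branching introduced in __init__."""
    if llm_config:
        # Corresponds to `self.max_seq_len = self.llm_config.export.max_seq_length`
        return llm_config.export.max_seq_length
    # Corresponds to the fallback `kwargs.get("max_seq_len", 128)`
    return kwargs.get("max_seq_len", 128)


cfg = SimpleNamespace(export=SimpleNamespace(max_seq_length=2048))
print(resolve_max_seq_len(llm_config=cfg, max_seq_len=512))  # 2048: llm_config wins
print(resolve_max_seq_len(max_seq_len=512))                  # 512: kwargs fallback
print(resolve_max_seq_len())                                 # 128: hard-coded default
```

One consequence of the all-or-nothing branching is that, once an llm_config is supplied, an individual kwarg such as max_seq_len is silently ignored rather than merged.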
@@ -165,7 +180,7 @@ def __init__(self, **kwargs):
        if model_name not in ["llama3", "llama3_1"]:
            model_args.rope_scale_factor = 32

-        if kwargs.get("verbose", False):
+        if self.verbose:
            print("============= weights ================")
            print("{key} : {weights.numel()} : {weights.size()}")
            for key, weights in checkpoint.items():
@@ -280,7 +295,7 @@ def __init__(self, **kwargs):
                    f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
                )
        if unexpected:
-            if kwargs.get("verbose", False):
+            if self.verbose:
                print(f"Unexpected keys: {unexpected}")

        # Prune the input layer if input_prune_map is provided
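
The two verbose hunks swap repeated kwargs.get("verbose", False) lookups for the single self.verbose resolved at the top of __init__. A small illustrative check (not from the PR; the SimpleNamespace objects are hypothetical stand-ins for the config type) of why that matters: the old per-site lookup stays False when verbosity arrives via llm_config.debug.verbose, while the new resolution honors it.

```python
from types import SimpleNamespace

kwargs = {"llm_config": SimpleNamespace(debug=SimpleNamespace(verbose=True))}

# Old pattern: each call site queried kwargs directly and never saw llm_config.
old_verbose = kwargs.get("verbose", False)

# New pattern: resolve once, preferring llm_config.debug.verbose when present.
llm_config = kwargs.get("llm_config", None)
new_verbose = llm_config.debug.verbose if llm_config else kwargs.get("verbose", False)

print(old_verbose, new_verbose)  # False True
```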