@@ -216,14 +216,15 @@ def __init__(
         embedding: bool = False,
         n_threads: Optional[int] = None,
         n_batch: int = 512,
-        n_gqa: Optional[int] = None,  # must be 8 for llama2 70b
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
+        rms_eps_norm: Optional[float] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -261,8 +262,6 @@ def __init__(

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -285,6 +284,12 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale

+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_eps_norm is not None:
+            self.params.rms_eps_norm = rms_eps_norm
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
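The constructor changes above simply pass the two new temporary keyword arguments through to the llama.cpp context parameters. A minimal usage sketch, assuming a local GGML model file (the path and the concrete rms_eps_norm value below are illustrative, not part of this diff):

    from llama_cpp import Llama

    # Hypothetical local model path; replace with your own GGML file.
    llm = Llama(
        model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",
        n_gqa=8,            # (TEMPORARY) must be 8 for LLaMA 2 70B, per the comment in the diff
        rms_eps_norm=1e-5,  # (TEMPORARY) RMSNorm epsilon; value here is only an example
    )

Leaving either argument as None keeps whatever default llama_context_default_params() provides, since the assignments are guarded by `is not None` checks.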
@@ -1526,6 +1531,10 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_eps_norm=self.params.rms_eps_norm,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1535,7 +1544,6 @@ def __setstate__(self, state):
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
@@ -1551,7 +1559,14 @@ def __setstate__(self, state):
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            n_gqa=state["n_gqa"],
+            ### TEMPORARY ###
+            rms_eps_norm=state["rms_eps_norm"],
             verbose=state["verbose"],
+            ### TEMPORARY ###
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )

     def save_state(self) -> LlamaState:
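The __getstate__/__setstate__ changes carry the two temporary fields through pickling, so a reloaded instance is reconstructed with the same n_gqa and rms_eps_norm values. A rough round-trip sketch, assuming `llm` is the instance from the earlier example and that the model file is still reachable at its original path:

    import pickle

    # __getstate__ records n_gqa and rms_eps_norm from self.params;
    # __setstate__ re-runs __init__ with those values, reloading the model.
    data = pickle.dumps(llm)
    restored = pickle.loads(data)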