Commit 8cd64d4

Add rms_eps_norm
1 parent e4431a6

1 file changed: +19 -4 lines

llama_cpp/llama.py

Lines changed: 19 additions & 4 deletions
@@ -216,14 +216,15 @@ def __init__(
         embedding: bool = False,
         n_threads: Optional[int] = None,
         n_batch: int = 512,
-        n_gqa: Optional[int] = None, # must be 8 for llama2 70b
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
+        rms_eps_norm: Optional[float] = None, # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.

@@ -261,8 +262,6 @@ def __init__(

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv

@@ -285,6 +284,12 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale

+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_eps_norm is not None:
+            self.params.rms_eps_norm = rms_eps_norm
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)

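The hunks above move the n_gqa override to sit next to the new rms_eps_norm override, after llama_context_default_params() has been built; both are forwarded to llama.cpp only when explicitly passed. A minimal usage sketch, assuming a local LLaMA 2 70B GGML file (the model path and epsilon value below are illustrative, not part of this commit):

from llama_cpp import Llama

# Both overrides are optional; when left as None the llama.cpp
# defaults are used unchanged.
llm = Llama(
    model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",  # hypothetical path
    n_gqa=8,            # (TEMPORARY) must be 8 for LLaMA 2 70B
    rms_eps_norm=1e-5,  # (TEMPORARY) RMSNorm epsilon override
)
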
@@ -1526,6 +1531,10 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_eps_norm=self.params.rms_eps_norm,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###

@@ -1535,7 +1544,6 @@ def __setstate__(self, state):
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],

@@ -1551,7 +1559,14 @@ def __setstate__(self, state):
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            n_gqa=state["n_gqa"],
+            ### TEMPORARY ###
+            rms_eps_norm=state["rms_eps_norm"],
             verbose=state["verbose"],
+            ### TEMPORARY ###
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )

     def save_state(self) -> LlamaState:
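
The __getstate__/__setstate__ hunks carry the two temporary fields through pickling, so a restored instance re-applies them in __init__ (which reloads the model from model_path). A round-trip sketch, assuming llm from the example above and the model file still on disk:

import pickle

# __getstate__ now records n_gqa and rms_eps_norm; __setstate__ feeds
# them back into __init__, which reloads the weights from model_path.
payload = pickle.dumps(llm)
restored = pickle.loads(payload)
assert restored.params.n_gqa == llm.params.n_gqa
assert restored.params.rms_eps_norm == llm.params.rms_eps_norm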
