Skip to content

Commit f1dcbb4

Browse files
SagsMug authored and Don Mahurin committed
Fix session loading and saving in low level example chat
1 parent 90ae8d3 commit f1dcbb4

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

examples/low_level_api_chat_cpp.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -138,16 +138,17 @@ def __init__(self, params: GptParams) -> None:
138138

139139
if (path.exists(self.params.path_session)):
140140
_session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
141-
_n_token_count_out = llama_cpp.c_int()
141+
_n_token_count_out = llama_cpp.c_size_t()
142142
if (llama_cpp.llama_load_session_file(
143143
self.ctx,
144144
self.params.path_session.encode("utf8"),
145145
_session_tokens,
146146
self.params.n_ctx,
147147
ctypes.byref(_n_token_count_out)
148-
) != 0):
148+
) != 1):
149149
print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
150150
return
151+
_n_token_count_out = _n_token_count_out.value
151152
self.session_tokens = _session_tokens[:_n_token_count_out]
152153
print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
153154
else:
@@ -161,19 +162,21 @@ def __init__(self, params: GptParams) -> None:
161162
raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
162163

163164
# debug message about similarity of saved session, if applicable
164-
n_matching_session_tokens = 0
165+
self.n_matching_session_tokens = 0
165166
if len(self.session_tokens) > 0:
166167
for id in self.session_tokens:
167-
if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
168+
if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
168169
break
169-
n_matching_session_tokens += 1
170+
self.n_matching_session_tokens += 1
170171

171-
if n_matching_session_tokens >= len(self.embd_inp):
172+
if self.n_matching_session_tokens >= len(self.embd_inp):
172173
print(f"session file has exact match for prompt!")
173-
elif n_matching_session_tokens < (len(self.embd_inp) / 2):
174-
print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
174+
elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
175+
print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
175176
else:
176-
print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
177+
print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
178+
179+
self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
177180

178181
# number of tokens to keep when resetting context
179182
if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
@@ -258,9 +261,6 @@ def __init__(self, params: GptParams) -> None:
258261
""", file=sys.stderr)
259262
self.set_color(CONSOLE_COLOR_PROMPT)
260263

261-
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
262-
263-
264264
# tokenize a prompt
265265
def _tokenize(self, prompt, bos=True):
266266
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
@@ -329,7 +329,7 @@ def generate(self):
329329
) != 0):
330330
raise Exception("Failed to llama_eval!")
331331

332-
if len(self.embd) > 0 and not len(self.params.path_session) > 0:
332+
if len(self.embd) > 0 and len(self.params.path_session) > 0:
333333
self.session_tokens.extend(self.embd)
334334
self.n_session_consumed = len(self.session_tokens)
335335

@@ -346,7 +346,7 @@ def generate(self):
346346
llama_cpp.llama_save_session_file(
347347
self.ctx,
348348
self.params.path_session.encode("utf8"),
349-
self.session_tokens,
349+
(llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
350350
len(self.session_tokens)
351351
)
352352

0 commit comments

Comments (0)