@@ -138,16 +138,17 @@ def __init__(self, params: GptParams) -> None:
 
             if (path.exists(self.params.path_session)):
                 _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
-                _n_token_count_out = llama_cpp.c_int()
+                _n_token_count_out = llama_cpp.c_size_t()
                 if (llama_cpp.llama_load_session_file(
                     self.ctx,
                     self.params.path_session.encode("utf8"),
                     _session_tokens,
                     self.params.n_ctx,
                     ctypes.byref(_n_token_count_out)
-                ) != 0):
+                ) != 1):
                     print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
                     return
+                _n_token_count_out = _n_token_count_out.value
                 self.session_tokens = _session_tokens[:_n_token_count_out]
                 print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
             else:
@@ -161,19 +162,21 @@ def __init__(self, params: GptParams) -> None:
             raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
 
         # debug message about similarity of saved session, if applicable
-        n_matching_session_tokens = 0
+        self.n_matching_session_tokens = 0
         if len(self.session_tokens) > 0:
             for id in self.session_tokens:
-                if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
+                if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
                     break
-                n_matching_session_tokens += 1
+                self.n_matching_session_tokens += 1
 
-            if n_matching_session_tokens >= len(self.embd_inp):
+            if self.n_matching_session_tokens >= len(self.embd_inp):
                 print(f"session file has exact match for prompt!")
-            elif n_matching_session_tokens < (len(self.embd_inp) / 2):
-                print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
+            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
+                print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
             else:
-                print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+                print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+
+        self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
 
         # number of tokens to keep when resetting context
         if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
@@ -258,9 +261,6 @@ def __init__(self, params: GptParams) -> None:
 """, file=sys.stderr)
         self.set_color(CONSOLE_COLOR_PROMPT)
 
-        self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
-
-
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
@@ -329,7 +329,7 @@ def generate(self):
             ) != 0):
                 raise Exception("Failed to llama_eval!")
 
-            if len(self.embd) > 0 and not len(self.params.path_session) > 0:
+            if len(self.embd) > 0 and len(self.params.path_session) > 0:
                 self.session_tokens.extend(self.embd)
                 self.n_session_consumed = len(self.session_tokens)
 
@@ -346,7 +346,7 @@ def generate(self):
                 llama_cpp.llama_save_session_file(
                     self.ctx,
                     self.params.path_session.encode("utf8"),
-                    self.session_tokens,
+                    (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                     len(self.session_tokens)
                 )
 
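The fixes above hinge on two ctypes idioms: reading a size_t out-parameter back as a plain Python int, and packing a Python list of token ids into a C array before handing it to the native API. A minimal standalone sketch of both patterns, using plain ctypes only; the llama_token alias here is an assumption for illustration, not the real llama_cpp binding:

    import ctypes

    llama_token = ctypes.c_int32  # assumed stand-in for llama_cpp.llama_token

    # Out-parameter pattern (as used with llama_load_session_file): the C call
    # writes the token count into a c_size_t passed via ctypes.byref(); reading
    # .value afterwards yields a plain Python int for slicing and printing.
    n_token_count_out = ctypes.c_size_t()
    n_token_count_out.value = 3            # stands in for what the C call would write
    n_tokens = n_token_count_out.value

    # Array-packing pattern (as used with llama_save_session_file): a Python list
    # of token ids is packed into a C array of llama_token of matching length.
    session_tokens = [1, 15043, 3186]
    c_tokens = (llama_token * len(session_tokens))(*session_tokens)

    print(n_tokens, list(c_tokens))        # -> 3 [1, 15043, 3186]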