
Commit fab064d

Remove unnecessary ffi calls
1 parent e5d596e commit fab064d
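
In short: `llama_n_vocab` and `llama_n_ctx` are now queried over the FFI boundary once in `__init__` and cached as `self._n_vocab` and `self._n_ctx`; `tokenize`, `eval`, `_sample`, and `_create_completion` reuse the cached values instead of repeating the C calls, which is safe because both values are fixed for the lifetime of the context. A minimal, self-contained sketch of the pattern (all names below are illustrative stand-ins, not llama_cpp's real API):

def ffi_n_ctx(ctx):
    # Stand-in for a ctypes round-trip such as llama_cpp.llama_n_ctx(ctx).
    return 2048

class Model:
    def __init__(self, ctx):
        self.ctx = ctx
        # One FFI call at construction; the context size never changes,
        # so later methods reuse the cached attribute.
        self._n_ctx = ffi_n_ctx(ctx)

    def check_budget(self, n_prompt: int, max_tokens: int) -> None:
        # Previously each call here would re-query ffi_n_ctx(self.ctx).
        if n_prompt + max_tokens > self._n_ctx:
            raise ValueError(f"Requested tokens exceed context window of {self._n_ctx}")

Model(ctx=None).check_budget(10, 20)  # ok; a 4096-token request would raise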

File tree

1 file changed: +18 -20 lines changed

llama_cpp/llama.py

Lines changed: 18 additions & 20 deletions
@@ -177,19 +177,19 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-        n_vocab = self.n_vocab()
-        n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * n_vocab)(
+        self._n_vocab = self.n_vocab()
+        self._n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * self._n_vocab)(
             *[
                 llama_cpp.llama_token_data(
                     id=llama_cpp.llama_token(i),
                     logit=llama_cpp.c_float(0.0),
                     p=llama_cpp.c_float(0.0),
                 )
-                for i in range(n_vocab)
+                for i in range(self._n_vocab)
             ]
         )
-        size = llama_cpp.c_size_t(n_vocab)
+        size = llama_cpp.c_size_t(self._n_vocab)
         sorted = False
         candidates = llama_cpp.llama_token_data_array(
             data=data,
@@ -213,18 +213,18 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
             A list of tokens.
         """
         assert self.ctx is not None
-        n_ctx = llama_cpp.llama_n_ctx(self.ctx)
-        tokens = (llama_cpp.llama_token * int(n_ctx))()
+        n_ctx = self._n_ctx
+        tokens = (llama_cpp.llama_token * n_ctx)()
         n_tokens = llama_cpp.llama_tokenize(
             self.ctx,
             text,
             tokens,
             llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
-        if int(n_tokens) < 0:
+        if n_tokens < 0:
             n_tokens = abs(n_tokens)
-            tokens = (llama_cpp.llama_token * int(n_tokens))()
+            tokens = (llama_cpp.llama_token * n_tokens)()
             n_tokens = llama_cpp.llama_tokenize(
                 self.ctx,
                 text,
@@ -275,7 +275,7 @@ def eval(self, tokens: Sequence[int]):
             tokens: The list of tokens to evaluate.
         """
         assert self.ctx is not None
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self.eval_tokens))
@@ -287,18 +287,16 @@ def eval(self, tokens: Sequence[int]):
                 n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
-            if int(return_code) != 0:
+            if return_code != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-            cols = int(n_vocab)
+            n_vocab = self._n_vocab
+            cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits: List[List[float]] = [
-                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
-            ]
+            logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
 
     def _sample(
@@ -319,8 +317,8 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = self.n_vocab()
-        n_ctx = self.n_ctx()
+        n_vocab = self._n_vocab
+        n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
@@ -654,9 +652,9 @@ def _create_completion(
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
+        if len(prompt_tokens) + max_tokens > self._n_ctx:
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens exceed context window of {self._n_ctx}"
             )
 
         if stop != []:
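
Aside from the cached lookups, the `eval` hunk also changes how logits are copied out: slicing the ctypes pointer returned by `llama_get_logits` (`logits_view[i * cols : (i + 1) * cols]`) materializes each row as a plain Python list in one bulk read, replacing the per-element nested comprehension. A small runnable sketch of the slicing idiom, with a ctypes array standing in for the logits buffer:

import ctypes

rows, cols = 2, 4
# Stand-in for the C float buffer that llama_get_logits points at.
buf = (ctypes.c_float * (rows * cols))(*range(rows * cols))
view = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

# Slicing a ctypes pointer copies the range out as a Python list of floats,
# equivalent to [view[i * cols + j] for j in range(cols)] for each row.
logits = [view[i * cols : (i + 1) * cols] for i in range(rows)]
print(logits)  # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]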
