Commit 517f9ed

Convert missed llama.cpp constants into standard python types
1 parent c4c440b commit 517f9ed
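The commit drops the remaining c_int/c_float/c_bool/c_char_p wrappers around values that are ultimately passed to llama.cpp functions whose argtypes the bindings already declare. A minimal sketch of the ctypes behaviour this relies on (libc's abs is used purely for illustration and is not part of the commit):

import ctypes
import ctypes.util

# When a foreign function has its argtypes declared, ctypes converts plain
# Python ints (and floats, for c_float/c_double parameters) automatically,
# so explicit c_int(...) wrappers at every call site are redundant.
libc = ctypes.CDLL(ctypes.util.find_library("c") or "msvcrt")
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int

print(libc.abs(-42))                # plain Python int -> 42
print(libc.abs(ctypes.c_int(-42)))  # an explicit c_int still works -> 42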

2 files changed: 86 additions & 86 deletions

llama_cpp/llama.py

Lines changed: 42 additions & 42 deletions
@@ -343,11 +343,11 @@ def __init__(
         if self.lora_path:
             if llama_cpp.llama_model_apply_lora_from_file(
                 self.model,
-                llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
-                llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
+                self.lora_path.encode("utf-8"),
+                self.lora_base.encode("utf-8")
                 if self.lora_base is not None
                 else llama_cpp.c_char_p(0),
-                llama_cpp.c_int(self.n_threads),
+                self.n_threads,
             ):
                 raise RuntimeError(
                     f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
@@ -358,8 +358,8 @@ def __init__(
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = llama_cpp.c_bool(False)
+        size = self._n_vocab
+        sorted = False
         self._candidates_data = np.array(
             [],
             dtype=np.dtype(
@@ -422,8 +422,8 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
             self.model,
             text,
             tokens,
-            llama_cpp.c_int(n_ctx),
-            llama_cpp.c_bool(add_bos),
+            n_ctx,
+            add_bos,
         )
         if n_tokens < 0:
             n_tokens = abs(n_tokens)
@@ -432,8 +432,8 @@ def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
                 self.model,
                 text,
                 tokens,
-                llama_cpp.c_int(n_tokens),
-                llama_cpp.c_bool(add_bos),
+                n_tokens,
+                add_bos,
             )
             if n_tokens < 0:
                 raise RuntimeError(
@@ -491,9 +491,9 @@ def eval(self, tokens: Sequence[int]):
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(n_tokens),
-                n_past=llama_cpp.c_int(n_past),
-                n_threads=llama_cpp.c_int(self.n_threads),
+                n_tokens=n_tokens,
+                n_past=n_past,
+                n_threads=self.n_threads,
             )
             if return_code != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
@@ -514,17 +514,17 @@ def eval(self, tokens: Sequence[int]):
     def _sample(
         self,
         last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
-        last_n_tokens_size: llama_cpp.c_int,
-        top_k: llama_cpp.c_int,
-        top_p: llama_cpp.c_float,
-        temp: llama_cpp.c_float,
-        tfs_z: llama_cpp.c_float,
-        repeat_penalty: llama_cpp.c_float,
-        frequency_penalty: llama_cpp.c_float,
-        presence_penalty: llama_cpp.c_float,
-        mirostat_mode: llama_cpp.c_int,
-        mirostat_tau: llama_cpp.c_float,
-        mirostat_eta: llama_cpp.c_float,
+        last_n_tokens_size: int,
+        top_k: int,
+        top_p: float,
+        temp: float,
+        tfs_z: float,
+        repeat_penalty: float,
+        frequency_penalty: float,
+        presence_penalty: float,
+        mirostat_mode: float,
+        mirostat_tau: float,
+        mirostat_eta: float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -533,10 +533,10 @@ def _sample(
         assert self.n_tokens > 0
         n_vocab = self._n_vocab
        n_ctx = self._n_ctx
-        top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
+        top_k = n_vocab if top_k <= 0 else top_k
         last_n_tokens_size = (
-            llama_cpp.c_int(n_ctx)
-            if last_n_tokens_size.value < 0
+            n_ctx
+            if last_n_tokens_size < 0
             else last_n_tokens_size
         )
         logits: npt.NDArray[np.single] = self._scores[-1, :]
@@ -578,13 +578,13 @@ def _sample(
                 grammar=grammar.grammar,
             )
 
-        if temp.value == 0.0:
+        if temp == 0.0:
             id = llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
             )
-        elif mirostat_mode.value == 1:
-            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+        elif mirostat_mode == 1:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau)
             mirostat_m = llama_cpp.c_int(100)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
@@ -599,8 +599,8 @@ def _sample(
                 mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
                 m=mirostat_m,
             )
-        elif mirostat_mode.value == 2:
-            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
+        elif mirostat_mode == 2:
+            mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
@@ -690,17 +690,17 @@ def sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
-            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
-            top_k=llama_cpp.c_int(top_k),
-            top_p=llama_cpp.c_float(top_p),
-            temp=llama_cpp.c_float(temp),
-            tfs_z=llama_cpp.c_float(tfs_z),
-            repeat_penalty=llama_cpp.c_float(repeat_penalty),
-            frequency_penalty=llama_cpp.c_float(frequency_penalty),
-            presence_penalty=llama_cpp.c_float(presence_penalty),
-            mirostat_mode=llama_cpp.c_int(mirostat_mode),
-            mirostat_tau=llama_cpp.c_float(mirostat_tau),
-            mirostat_eta=llama_cpp.c_float(mirostat_eta),
+            last_n_tokens_size=self.last_n_tokens_size,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temp,
+            tfs_z=tfs_z,
+            repeat_penalty=repeat_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,

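Note that the commit leaves wrappers in place where a real ctypes object is still required: mirostat_mu above is still constructed with llama_cpp.c_float because it is passed by reference (ctypes.byref) and mutated by the sampler. A minimal sketch of why (illustrative only, not from the commit):

import ctypes

mu = ctypes.c_float(2.0 * 5.0)  # an in/out parameter needs a ctypes instance
ref = ctypes.byref(mu)          # byref() rejects plain Python floats
mu.value += 1.0                 # callees see updates through the reference
print(mu.value)
# ctypes.byref(2.0) would raise TypeError: byref() argument must be a ctypes instance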
llama_cpp/llama_cpp.py

Lines changed: 44 additions & 44 deletions
@@ -91,15 +91,15 @@ def _load_shared_library(lib_base_name: str):
 LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else 1
 
 # define LLAMA_DEFAULT_SEED 0xFFFFFFFF
-LLAMA_DEFAULT_SEED = ctypes.c_int(0xFFFFFFFF)
+LLAMA_DEFAULT_SEED = 0xFFFFFFFF
 
 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
-LLAMA_FILE_MAGIC_GGSN = ctypes.c_uint(0x6767736E)
+LLAMA_FILE_MAGIC_GGSN = 0x6767736E
 
 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
 # define LLAMA_SESSION_VERSION 1
-LLAMA_SESSION_VERSION = ctypes.c_int(1)
+LLAMA_SESSION_VERSION = 1
 
 
 # struct llama_model;
@@ -118,16 +118,16 @@ def _load_shared_library(lib_base_name: str):
 # LLAMA_LOG_LEVEL_WARN = 3,
 # LLAMA_LOG_LEVEL_INFO = 4
 # };
-LLAMA_LOG_LEVEL_ERROR = c_int(2)
-LLAMA_LOG_LEVEL_WARN = c_int(3)
-LLAMA_LOG_LEVEL_INFO = c_int(4)
+LLAMA_LOG_LEVEL_ERROR = 2
+LLAMA_LOG_LEVEL_WARN = 3
+LLAMA_LOG_LEVEL_INFO = 4
 
 # enum llama_vocab_type {
 # LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
 # LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
 # };
-LLAMA_VOCAB_TYPE_SPM = c_int(0)
-LLAMA_VOCAB_TYPE_BPE = c_int(1)
+LLAMA_VOCAB_TYPE_SPM = 0
+LLAMA_VOCAB_TYPE_BPE = 1
 
 
 # enum llama_token_type {
@@ -139,13 +139,13 @@ def _load_shared_library(lib_base_name: str):
 # LLAMA_TOKEN_TYPE_UNUSED = 5,
 # LLAMA_TOKEN_TYPE_BYTE = 6,
 # };
-LLAMA_TOKEN_TYPE_UNDEFINED = c_int(0)
-LLAMA_TOKEN_TYPE_NORMAL = c_int(1)
-LLAMA_TOKEN_TYPE_UNKNOWN = c_int(2)
-LLAMA_TOKEN_TYPE_CONTROL = c_int(3)
-LLAMA_TOKEN_TYPE_USER_DEFINED = c_int(4)
-LLAMA_TOKEN_TYPE_UNUSED = c_int(5)
-LLAMA_TOKEN_TYPE_BYTE = c_int(6)
+LLAMA_TOKEN_TYPE_UNDEFINED = 0
+LLAMA_TOKEN_TYPE_NORMAL = 1
+LLAMA_TOKEN_TYPE_UNKNOWN = 2
+LLAMA_TOKEN_TYPE_CONTROL = 3
+LLAMA_TOKEN_TYPE_USER_DEFINED = 4
+LLAMA_TOKEN_TYPE_UNUSED = 5
+LLAMA_TOKEN_TYPE_BYTE = 6
 
 # enum llama_ftype {
 # LLAMA_FTYPE_ALL_F32 = 0,
@@ -170,24 +170,24 @@ def _load_shared_library(lib_base_name: str):
 #
 # LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
 # };
-LLAMA_FTYPE_ALL_F32 = c_int(0)
-LLAMA_FTYPE_MOSTLY_F16 = c_int(1)
-LLAMA_FTYPE_MOSTLY_Q4_0 = c_int(2)
-LLAMA_FTYPE_MOSTLY_Q4_1 = c_int(3)
-LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = c_int(4)
-LLAMA_FTYPE_MOSTLY_Q8_0 = c_int(7)
-LLAMA_FTYPE_MOSTLY_Q5_0 = c_int(8)
-LLAMA_FTYPE_MOSTLY_Q5_1 = c_int(9)
-LLAMA_FTYPE_MOSTLY_Q2_K = c_int(10)
-LLAMA_FTYPE_MOSTLY_Q3_K_S = c_int(11)
-LLAMA_FTYPE_MOSTLY_Q3_K_M = c_int(12)
-LLAMA_FTYPE_MOSTLY_Q3_K_L = c_int(13)
-LLAMA_FTYPE_MOSTLY_Q4_K_S = c_int(14)
-LLAMA_FTYPE_MOSTLY_Q4_K_M = c_int(15)
-LLAMA_FTYPE_MOSTLY_Q5_K_S = c_int(16)
-LLAMA_FTYPE_MOSTLY_Q5_K_M = c_int(17)
-LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)
-LLAMA_FTYPE_GUESSED = c_int(1024)
+LLAMA_FTYPE_ALL_F32 = 0
+LLAMA_FTYPE_MOSTLY_F16 = 1
+LLAMA_FTYPE_MOSTLY_Q4_0 = 2
+LLAMA_FTYPE_MOSTLY_Q4_1 = 3
+LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4
+LLAMA_FTYPE_MOSTLY_Q8_0 = 7
+LLAMA_FTYPE_MOSTLY_Q5_0 = 8
+LLAMA_FTYPE_MOSTLY_Q5_1 = 9
+LLAMA_FTYPE_MOSTLY_Q2_K = 10
+LLAMA_FTYPE_MOSTLY_Q3_K_S = 11
+LLAMA_FTYPE_MOSTLY_Q3_K_M = 12
+LLAMA_FTYPE_MOSTLY_Q3_K_L = 13
+LLAMA_FTYPE_MOSTLY_Q4_K_S = 14
+LLAMA_FTYPE_MOSTLY_Q4_K_M = 15
+LLAMA_FTYPE_MOSTLY_Q5_K_S = 16
+LLAMA_FTYPE_MOSTLY_Q5_K_M = 17
+LLAMA_FTYPE_MOSTLY_Q6_K = 18
+LLAMA_FTYPE_GUESSED = 1024
 
 
 # typedef struct llama_token_data {
@@ -589,7 +589,7 @@ def llama_model_n_embd(model: llama_model_p) -> int:
 
 # // Get a string describing the model type
 # LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
-def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
+def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int]) -> int:
     return _lib.llama_model_desc(model, buf, buf_size)
 
 
@@ -957,8 +957,8 @@ def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
     tokens, # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, int],
 ) -> int:
     return _lib.llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos)
 
@@ -977,8 +977,8 @@ def llama_tokenize_with_model(
     model: llama_model_p,
     text: bytes,
     tokens, # type: Array[llama_token]
-    n_max_tokens: c_int,
-    add_bos: c_bool,
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
 ) -> int:
     return _lib.llama_tokenize_with_model(model, text, tokens, n_max_tokens, add_bos)
 
@@ -1003,7 +1003,7 @@ def llama_tokenize_with_model(
 # char * buf,
 # int length);
 def llama_token_to_piece(
-    ctx: llama_context_p, token: llama_token, buf: bytes, length: c_int
+    ctx: llama_context_p, token: llama_token, buf: bytes, length: Union[c_int, int]
 ) -> int:
     return _lib.llama_token_to_piece(ctx, token, buf, length)
 
@@ -1018,7 +1018,7 @@ def llama_token_to_piece(
 # char * buf,
 # int length);
 def llama_token_to_piece_with_model(
-    model: llama_model_p, token: llama_token, buf: bytes, length: c_int
+    model: llama_model_p, token: llama_token, buf: bytes, length: Union[c_int, int]
 ) -> int:
     return _lib.llama_token_to_piece_with_model(model, token, buf, length)
 
@@ -1453,10 +1453,10 @@ def llama_beam_search(
     ctx: llama_context_p,
     callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore
     callback_data: c_void_p,
-    n_beams: c_size_t,
-    n_past: c_int,
-    n_predict: c_int,
-    n_threads: c_int,
+    n_beams: Union[c_size_t, int],
+    n_past: Union[c_int, int],
+    n_predict: Union[c_int, int],
+    n_threads: Union[c_int, int],
 ):
     return _lib.llama_beam_search(
         ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads

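A practical effect of converting the module-level constants: plain ints compare and hash like ordinary Python values, whereas ctypes instances fall back to identity comparison. A small sketch of the difference (constant names are taken from the diff above; the llama_cpp usage is commented out and assumes the package is importable):

import ctypes

print(ctypes.c_int(0) == 0)                 # False: c_int has no value-based __eq__
print(ctypes.c_int(0) == ctypes.c_int(0))   # also False

# After this commit the constants behave as expected, e.g.:
# import llama_cpp
# llama_cpp.LLAMA_VOCAB_TYPE_SPM == 0   # True, since it is now a plain int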