
Commit 3dc21b2

tests: Improve llama.cpp mock
1 parent 63fe137 commit 3dc21b2

1 file changed: +92 −51 lines changed

tests/test_llama.py

Lines changed: 92 additions & 51 deletions
@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
     assert tokens[-1] == llama.token_eos()
     assert tokens == [1, 15043, 2787, 2]

-
-def test_llama_patch(monkeypatch):
+    text = b""
+    tokens = llama.tokenize(text, add_bos=True, special=True)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [llama.token_bos()]
+    assert text == llama.detokenize(tokens)
+
+
+@pytest.fixture
+def mock_llama(monkeypatch):
+    def setup_mock(llama: llama_cpp.Llama, output_text: str):
+        llama.reset()
+        n_vocab = llama.n_vocab()
+        output_tokens = llama.tokenize(
+            output_text.encode("utf-8"), add_bos=True, special=True
+        )
+        n = 0
+        last_n_tokens = 0
+
+        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+            nonlocal n
+            nonlocal last_n_tokens
+            # Test some basic invariants of this mocking technique
+            assert ctx == llama._ctx.ctx
+            assert llama.n_tokens == n
+            assert batch.n_tokens > 0
+            n += batch.n_tokens
+            last_n_tokens = batch.n_tokens
+            return 0
+
+        def mock_get_logits(*args, **kwargs):
+            nonlocal last_n_tokens
+            size = n_vocab * last_n_tokens
+            return (llama_cpp.c_float * size)()
+
+        def mock_sample(*args, **kwargs):
+            nonlocal n
+            if n < len(output_tokens):
+                return output_tokens[n]
+            else:
+                return llama.token_eos()
+
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+
+    return setup_mock
+
+
+def test_llama_patch(mock_llama):
     n_ctx = 128
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
     assert n_vocab == 32000

-    ## Set up mock function
-    def mock_decode(*args, **kwargs):
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
     text = "The quick brown fox"
-    text_tokens = llama.tokenize(text.encode("utf-8"), add_bos=True, special=True)
     output_text = " jumps over the lazy dog."
-    all_text_tokens = llama.tokenize((text + output_text).encode("utf-8"), add_bos=True, special=True)
-    output_tokens = all_text_tokens[len(text_tokens):]
-    token_eos = llama.token_eos()
-    n = 0
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n < len(output_tokens):
-            n += 1
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+    all_text = text + output_text

+    ## Test basic completion from bos until eos
+    mock_llama(llama, all_text)
+    completion = llama.create_completion("", max_tokens=36)
+    assert completion["choices"][0]["text"] == all_text
+    assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until eos
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20)
     assert completion["choices"][0]["text"] == output_text
     assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test streaming completion until eos
-    n = 0 # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until stop sequence
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
     assert completion["choices"][0]["text"] == " jumps over the "
     assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test streaming completion until stop sequence
-    n = 0 # reset
-    chunks = list(llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]))
+    mock_llama(llama, all_text)
+    chunks = list(
+        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    )
     assert (
         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
     )
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until length
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=2)
     assert completion["choices"][0]["text"] == " jumps"
     assert completion["choices"][0]["finish_reason"] == "length"

     ## Test streaming completion until length
-    n = 0 # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
@@ -131,44 +160,55 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text


-def test_utf8(monkeypatch):
-    n_ctx = 512
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
+def test_utf8(mock_llama, monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
+    n_ctx = llama.n_ctx()
     n_vocab = llama.n_vocab()

+    output_text = "😀"
+    output_tokens = llama.tokenize(
+        output_text.encode("utf-8"), add_bos=True, special=True
+    )
+    token_eos = llama.token_eos()
+    n = 0
+
+    def reset():
+        nonlocal n
+        llama.reset()
+        n = 0
+
     ## Set up mock function
-    def mock_decode(*args, **kwargs):
+    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+        nonlocal n
+        assert batch.n_tokens > 0
+        assert llama.n_tokens == n
+        n += batch.n_tokens
         return 0

     def mock_get_logits(*args, **kwargs):
         size = n_vocab * n_ctx
         return (llama_cpp.c_float * size)()

-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-    output_text = "😀"
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
-    token_eos = llama.token_eos()
-    n = 0
-
     def mock_sample(*args, **kwargs):
         nonlocal n
-        if n < len(output_tokens):
-            n += 1
+        if n <= len(output_tokens):
             return output_tokens[n - 1]
         else:
             return token_eos

+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     ## Test basic completion with utf8 multibyte
-    n = 0 # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text

     ## Test basic completion with incomplete utf8 multibyte
-    n = 0 # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""

@@ -196,5 +236,6 @@ def test_llama_server():
         ],
     }

+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
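Note: the new mock_llama fixture replaces the per-test monkeypatching of llama_decode, llama_get_logits, and llama_sample_token with one helper that scripts what a vocab-only model "generates". A minimal sketch of how another test might consume it, mirroring the pattern test_llama_patch uses above (the test name is hypothetical; MODEL is the module-level model path the test file already defines):

def test_scripted_completion(mock_llama):
    # vocab_only=True loads only the tokenizer; decode/sample calls are mocked,
    # so no model weights ever run.
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    prompt = "The quick brown fox"
    continuation = " jumps over the lazy dog."

    # Script the full text; the fixture resets the model and replays these
    # tokens from the patched sampler until it falls back to EOS.
    mock_llama(llama, prompt + continuation)

    completion = llama.create_completion(prompt, max_tokens=20)
    assert completion["choices"][0]["text"] == continuation
    assert completion["choices"][0]["finish_reason"] == "stop"

Because each create_completion call advances the mock's internal token counter, the tests above re-arm the mock with mock_llama(llama, all_text) before every sub-case instead of resetting a shared n = 0 counter by hand.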
