Commit efe8e6f

llama_cpp server: slight refactor to init_llama function
Define an init_llama function that starts llama with the supplied settings instead of doing it in the global context of app.py. This makes the test less brittle: it no longer has to modify os.environ before importing the app.
1 parent 6d8db9d

3 files changed: +30 −25 lines

llama_cpp/server/__main__.py

Lines changed: 2 additions & 1 deletion
@@ -24,9 +24,10 @@
 import os
 import uvicorn

-from llama_cpp.server.app import app
+from llama_cpp.server.app import app, init_llama

 if __name__ == "__main__":
+    init_llama()

     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
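
Because init_llama() now takes an optional Settings object, the server can also be started programmatically instead of through environment variables. A minimal sketch under that assumption; the model path is a placeholder:

    import uvicorn

    from llama_cpp.server.app import Settings, app, init_llama

    # Placeholder path; point this at any model file llama_cpp can load.
    settings = Settings()
    settings.model = "./models/ggml-model.bin"

    init_llama(settings)  # builds the shared llama_cpp.Llama from these settings
    uvicorn.run(app, host="localhost", port=8000)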

llama_cpp/server/app.py

Lines changed: 23 additions & 20 deletions
@@ -13,7 +13,7 @@


 class Settings(BaseSettings):
-    model: str = os.environ["MODEL"]
+    model: str = os.environ.get("MODEL", "null")
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -38,31 +38,34 @@ class Settings(BaseSettings):
     allow_methods=["*"],
     allow_headers=["*"],
 )
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    use_mmap=settings.use_mmap,
-    embedding=settings.embedding,
-    logits_all=settings.logits_all,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-    vocab_only=settings.vocab_only,
-)
-if settings.cache:
-    cache = llama_cpp.LlamaCache()
-    llama.set_cache(cache)
-llama_lock = Lock()

+llama: llama_cpp.Llama = None
+def init_llama(settings: Settings = None):
+    if settings is None:
+        settings = Settings()
+    global llama
+    llama = llama_cpp.Llama(
+        settings.model,
+        f16_kv=settings.f16_kv,
+        use_mlock=settings.use_mlock,
+        use_mmap=settings.use_mmap,
+        embedding=settings.embedding,
+        logits_all=settings.logits_all,
+        n_threads=settings.n_threads,
+        n_batch=settings.n_batch,
+        n_ctx=settings.n_ctx,
+        last_n_tokens_size=settings.last_n_tokens_size,
+        vocab_only=settings.vocab_only,
+    )
+    if settings.cache:
+        cache = llama_cpp.LlamaCache()
+        llama.set_cache(cache)

+llama_lock = Lock()
 def get_llama():
     with llama_lock:
         yield llama

-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
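
For context, route handlers elsewhere in app.py are expected to receive the shared instance through FastAPI dependency injection: get_llama() yields the module-level object while holding llama_lock, so requests use the model one at a time. A sketch of that pattern with a hypothetical route (the handler below is illustrative and not part of this diff):

    import llama_cpp
    from fastapi import Depends

    from llama_cpp.server.app import app, get_llama, init_llama

    init_llama()  # no Settings supplied, so it falls back to the MODEL environment variable

    @app.get("/v1/illustrative-status")  # hypothetical route, for illustration only
    def illustrative_status(llama: llama_cpp.Llama = Depends(get_llama)):
        # `llama` is the instance created by init_llama(); the generator dependency
        # keeps llama_lock held for the duration of the request.
        return {"model_loaded": llama is not None}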

tests/test_llama.py

Lines changed: 5 additions & 4 deletions
@@ -132,10 +132,11 @@ def mock_sample(*args, **kwargs):

 def test_llama_server():
     from fastapi.testclient import TestClient
-    import os
-    os.environ["MODEL"] = MODEL
-    os.environ["VOCAB_ONLY"] = "true"
-    from llama_cpp.server.app import app
+    from llama_cpp.server.app import app, init_llama, Settings
+    s = Settings()
+    s.model = MODEL
+    s.vocab_only = True
+    init_llama(s)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
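
Because Settings is a pydantic BaseSettings, the same overrides could presumably also be passed as constructor keyword arguments instead of being mutated on an instance afterwards; a sketch under that assumption (MODEL is the test module's existing model-path constant):

    from llama_cpp.server.app import Settings, init_llama

    # Keyword overrides take precedence over environment variables and field defaults.
    init_llama(Settings(model=MODEL, vocab_only=True))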
