
Commit 79ba9ed

Merge pull request #125 from Stonelinks/app-server-module-importable
Make app server module importable
2 parents 755f9fa + efe8e6f commit 79ba9ed
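
What "importable" buys you: the FastAPI app and its model setup no longer live only inside `__main__.py`; as the `__main__.py` diff below shows, they are imported from `llama_cpp.server.app`, so the server can be embedded in another process. A minimal sketch, assuming only the `app` and `init_llama` names that the new `__main__.py` imports (the `MODEL` path is a placeholder):

```python
# Run the llama.cpp server from your own script rather than via
# `python3 -m llama_cpp.server`. Assumes llama_cpp.server.app exposes
# `app` and `init_llama`, as the imports in the new __main__.py suggest.
import os
import uvicorn

from llama_cpp.server.app import app, init_llama

os.environ["MODEL"] = "../models/7B/..."  # placeholder, point at your model file
init_llama()  # assumed to read Settings (MODEL, n_ctx, ...) from the environment and load the model
uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
```

This mirrors what `__main__.py` itself now does under its `if __name__ == "__main__":` guard, just without the guard.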

File tree

6 files changed: +401 additions, -270 deletions


llama_cpp/server/__init__.py

Whitespace-only changes.
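
An `__init__.py` is what marks `llama_cpp/server` as a regular Python package, which is the mechanical basis for the commit title. Nothing beyond standard packaging is involved; a rough illustration:

```python
# Because llama_cpp/server/ ships __init__.py and __main__.py, both of these
# resolve through normal package machinery:
import llama_cpp.server                 # the package itself
from llama_cpp.server.app import app    # the FastAPI app (module name taken from the diff below)
# ...and `python3 -m llama_cpp.server` executes llama_cpp/server/__main__.py.
```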

llama_cpp/server/__main__.py

Lines changed: 14 additions & 268 deletions
@@ -5,283 +5,29 @@
 ```bash
 pip install fastapi uvicorn sse-starlette
 export MODEL=../models/7B/...
-uvicorn fastapi_server_chat:app --reload
 ```
 
-Then visit http://localhost:8000/docs to see the interactive API docs.
-
-"""
-import os
-import json
-from threading import Lock
-from typing import List, Optional, Literal, Union, Iterator, Dict
-from typing_extensions import TypedDict
-
-import llama_cpp
-
-from fastapi import Depends, FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
-from sse_starlette.sse import EventSourceResponse
-
-
-class Settings(BaseSettings):
-    model: str
-    n_ctx: int = 2048
-    n_batch: int = 512
-    n_threads: int = max((os.cpu_count() or 2) // 2, 1)
-    f16_kv: bool = True
-    use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
-    use_mmap: bool = True
-    embedding: bool = True
-    last_n_tokens_size: int = 64
-    logits_all: bool = False
-    cache: bool = False  # WARNING: This is an experimental feature
-
-
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    use_mmap=settings.use_mmap,
-    embedding=settings.embedding,
-    logits_all=settings.logits_all,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-)
-if settings.cache:
-    cache = llama_cpp.LlamaCache()
-    llama.set_cache(cache)
-llama_lock = Lock()
-
-
-def get_llama():
-    with llama_lock:
-        yield llama
-
-
-class CreateCompletionRequest(BaseModel):
-    prompt: Union[str, List[str]]
-    suffix: Optional[str] = Field(None)
-    max_tokens: int = 16
-    temperature: float = 0.8
-    top_p: float = 0.95
-    echo: bool = False
-    stop: Optional[List[str]] = []
-    stream: bool = False
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    logprobs: Optional[int] = Field(None)
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    top_k: int = 40
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
-        }
-
-
-CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
-
-
-@app.post(
-    "/v1/completions",
-    response_model=CreateCompletionResponse,
-)
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
-):
-    if isinstance(request.prompt, list):
-        request.prompt = "".join(request.prompt)
-
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "frequency_penalty",
-                "presence_penalty",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
-
-
-class CreateEmbeddingRequest(BaseModel):
-    model: Optional[str]
-    input: str
-    user: Optional[str]
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
-        }
-
-
-CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
-
-
-@app.post(
-    "/v1/embeddings",
-    response_model=CreateEmbeddingResponse,
-)
-def create_embedding(
-    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
-):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
-
-
-class ChatCompletionRequestMessage(BaseModel):
-    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
-    content: str
-    user: Optional[str] = None
-
-
-class CreateChatCompletionRequest(BaseModel):
-    model: Optional[str]
-    messages: List[ChatCompletionRequestMessage]
-    temperature: float = 0.8
-    top_p: float = 0.95
-    stream: bool = False
-    stop: Optional[List[str]] = []
-    max_tokens: int = 128
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
-        }
-
-
-CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
-
-
-@app.post(
-    "/v1/chat/completions",
-    response_model=CreateChatCompletionResponse,
-)
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
-    llama: llama_cpp.Llama = Depends(get_llama),
-) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "presence_penalty",
-                "frequency_penalty",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
-
-        return EventSourceResponse(
-            server_sent_events(chunks),
-        )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
-
-
-class ModelData(TypedDict):
-    id: str
-    object: Literal["model"]
-    owned_by: str
-    permissions: List[str]
-
-
-class ModelList(TypedDict):
-    object: Literal["list"]
-    data: List[ModelData]
+Then run:
+```
+uvicorn llama_cpp.server.app:app --reload
+```
 
+or
 
-GetModelResponse = create_model_from_typeddict(ModelList)
+```
+python3 -m llama_cpp.server
+```
 
+Then visit http://localhost:8000/docs to see the interactive API docs.
 
-@app.get("/v1/models", response_model=GetModelResponse)
-def get_models() -> ModelList:
-    return {
-        "object": "list",
-        "data": [
-            {
-                "id": llama.model_path,
-                "object": "model",
-                "owned_by": "me",
-                "permissions": [],
-            }
-        ],
-    }
+"""
+import os
+import uvicorn
 
+from llama_cpp.server.app import app, init_llama
 
 if __name__ == "__main__":
-    import os
-    import uvicorn
+    init_llama()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
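
For reference, the handlers removed above define OpenAI-style routes (`/v1/completions`, `/v1/embeddings`, `/v1/chat/completions`, `/v1/models`); given the +401/-270 totals they are presumably relocated into `llama_cpp/server/app.py` rather than dropped, though that file is not shown here. A smoke-test sketch against a running server, with the payload fields taken from the `CreateCompletionRequest` example above and nothing assumed beyond an OpenAI-compatible response shape:

```python
# POST a completion request to a locally running server
# (started with `python3 -m llama_cpp.server` and MODEL set).
import json
import urllib.request

payload = {
    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
    "stop": ["\n", "###"],
    "max_tokens": 16,
    "temperature": 0.8,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    completion = json.load(resp)

# OpenAI-style completion objects carry the generated text in choices[0]["text"].
print(completion["choices"][0]["text"])
```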
