
Commit 7499fc1

Merge pull request #126 from Stonelinks/deprecate-example-server
Deprecate example server
2 parents 1971514 + 0fcc25c commit 7499fc1

1 file changed: +19 -244 lines changed


examples/high_level_api/fastapi_server.py

Lines changed: 19 additions & 244 deletions
@@ -4,259 +4,34 @@

 ```bash
 pip install fastapi uvicorn sse-starlette
-export MODEL=../models/7B/ggml-model.bin
-uvicorn fastapi_server_chat:app --reload
+export MODEL=../models/7B/...
 ```

-Then visit http://localhost:8000/docs to see the interactive API docs.
-
-"""
-import os
-import json
-from typing import List, Optional, Literal, Union, Iterator, Dict
-from typing_extensions import TypedDict
-
-import llama_cpp
-
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
-from sse_starlette.sse import EventSourceResponse
-
-
-class Settings(BaseSettings):
-    model: str
-    n_ctx: int = 2048
-    n_batch: int = 8
-    n_threads: int = int(os.cpu_count() / 2) or 1
-    f16_kv: bool = True
-    use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
-    embedding: bool = True
-    last_n_tokens_size: int = 64
-
-
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    embedding=settings.embedding,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-)
-
-
-class CreateCompletionRequest(BaseModel):
-    prompt: str
-    suffix: Optional[str] = Field(None)
-    max_tokens: int = 16
-    temperature: float = 0.8
-    top_p: float = 0.95
-    echo: bool = False
-    stop: List[str] = []
-    stream: bool = False
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    logprobs: Optional[int] = Field(None)
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    top_k: int = 40
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
-        }
-
-
-CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
-
-
-@app.post(
-    "/v1/completions",
-    response_model=CreateCompletionResponse,
-)
-def create_completion(request: CreateCompletionRequest):
-    if request.stream:
-        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
-        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "logprobs",
-                "frequency_penalty",
-                "presence_penalty",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-
-
-class CreateEmbeddingRequest(BaseModel):
-    model: Optional[str]
-    input: str
-    user: Optional[str]
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
-        }
-
-
-CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
-
-
-@app.post(
-    "/v1/embeddings",
-    response_model=CreateEmbeddingResponse,
-)
-def create_embedding(request: CreateEmbeddingRequest):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
-
-
-class ChatCompletionRequestMessage(BaseModel):
-    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
-    content: str
-    user: Optional[str] = None
-
-
-class CreateChatCompletionRequest(BaseModel):
-    model: Optional[str]
-    messages: List[ChatCompletionRequestMessage]
-    temperature: float = 0.8
-    top_p: float = 0.95
-    stream: bool = False
-    stop: List[str] = []
-    max_tokens: int = 128
-
-    # ignored or currently unsupported
-    model: Optional[str] = Field(None)
-    n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0
-    frequency_penalty: Optional[float] = 0
-    logit_bias: Optional[Dict[str, float]] = Field(None)
-    user: Optional[str] = Field(None)
-
-    # llama.cpp specific parameters
-    repeat_penalty: float = 1.1
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
-        }
-
-
-CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
-
-
-@app.post(
-    "/v1/chat/completions",
-    response_model=CreateChatCompletionResponse,
-)
-async def create_chat_completion(
-    request: CreateChatCompletionRequest,
-) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "model",
-                "n",
-                "presence_penalty",
-                "frequency_penalty",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
-
-        return EventSourceResponse(
-            server_sent_events(chunks),
-        )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
-
-
-class ModelData(TypedDict):
-    id: str
-    object: Literal["model"]
-    owned_by: str
-    permissions: List[str]
+Then run:
+```
+uvicorn llama_cpp.server.app:app --reload
+```

+or

-class ModelList(TypedDict):
-    object: Literal["list"]
-    data: List[ModelData]
+```
+python3 -m llama_cpp.server
+```

+Then visit http://localhost:8000/docs to see the interactive API docs.

-GetModelResponse = create_model_from_typeddict(ModelList)

+To actually see the implementation of the server, see llama_cpp/server/app.py

-@app.get("/v1/models", response_model=GetModelResponse)
-def get_models() -> ModelList:
-    return {
-        "object": "list",
-        "data": [
-            {
-                "id": llama.model_path,
-                "object": "model",
-                "owned_by": "me",
-                "permissions": [],
-            }
-        ],
-    }
+"""
+import os
+import uvicorn

+from llama_cpp.server.app import create_app

 if __name__ == "__main__":
-    import os
-    import uvicorn
+    app = create_app()

-    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=os.getenv("PORT", 8000))
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
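The removed module defined the FastAPI app inline; after this change the example only wraps the packaged server, and the new docstring points to llama_cpp/server/app.py for the OpenAI-style routes (/v1/completions, /v1/embeddings, /v1/chat/completions, /v1/models) that the deleted code used to define. As a rough client sketch against the relocated server on its default localhost:8000 address, assuming the /v1/completions request schema still matches the fields of the removed CreateCompletionRequest model and that the `requests` package is installed (both are assumptions, not part of this commit):

```python
# Hypothetical client sketch: query a server started with `python3 -m llama_cpp.server`
# (listening on localhost:8000 per the docstring above). Assumes `requests` is installed
# and /v1/completions still accepts the fields of the removed CreateCompletionRequest.
import requests

payload = {
    # Values taken from the schema_extra example in the removed model definition.
    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
    "stop": ["\n", "###"],
    "max_tokens": 16,
}

resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()

# The response is expected to follow the OpenAI-style llama_cpp.Completion shape.
print(resp.json()["choices"][0]["text"])
```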

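The removed chat handler streamed results as server-sent events, emitting one `data:` JSON chunk per event and a final `data: [DONE]` sentinel. A minimal consumption sketch under the same assumptions (`requests` installed, /v1/chat/completions schema and SSE framing unchanged in the relocated server):

```python
# Hypothetical SSE consumption sketch for the streaming chat endpoint. Assumes the
# relocated server streams `data: <json>` lines followed by `data: [DONE]`, as the
# removed create_chat_completion handler did.
import json
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "stream": True,
}

with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue  # skip blank keep-alive lines between SSE events
        text = line.decode("utf-8")
        if not text.startswith("data: "):
            continue
        data = text[len("data: "):]
        if data == "[DONE]":
            break  # end-of-stream sentinel emitted by the server
        chunk = json.loads(data)  # one ChatCompletionChunk-style dict per event
        print(chunk)
```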