Commit 80066f0: Use async routes
1 parent c2b59a5

1 file changed
llama_cpp/server/app.py

Lines changed: 88 additions & 57 deletions
@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal

 import llama_cpp

-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
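
For context, a minimal sketch (not part of the commit) of what the two starlette helpers imported above do: run_in_threadpool awaits a blocking call on a worker thread, and iterate_in_threadpool pulls items from a blocking iterator the same way, so the event loop stays responsive. blocking_generate below is a hypothetical stand-in for a blocking llama_cpp.Llama call.

import asyncio
import time
from typing import Iterator

from starlette.concurrency import iterate_in_threadpool, run_in_threadpool


def blocking_generate(n: int) -> Iterator[str]:
    # Hypothetical stand-in for a CPU-bound, blocking token generator.
    for i in range(n):
        time.sleep(0.1)  # simulate inference work
        yield f"token-{i}"


async def main() -> None:
    # Run the blocking setup call on a worker thread instead of the event loop.
    iterator = await run_in_threadpool(blocking_generate, 3)
    # Drain the blocking iterator one item at a time, each next() in the threadpool.
    async for token in iterate_in_threadpool(iterator):
        print(token)


asyncio.run(main())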
@@ -241,35 +245,49 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
-
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e

-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion


 class CreateEmbeddingRequest(BaseModel):
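
The streaming branch above replaces an async generator with a producer/consumer pair: event_publisher pushes SSE events into an anyio memory object stream while EventSourceResponse drains the receive side, started via data_sender_callable. A stripped-down, self-contained sketch of the same pattern follows; fake_token_stream and the /stream route are illustrative names, not part of the commit.

from functools import partial
from typing import Iterator

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

app = FastAPI()


def fake_token_stream() -> Iterator[str]:
    # Blocking iterator standing in for model inference.
    yield from ("hello", "world")


@app.get("/stream")
async def stream(request: Request):
    # Bounded buffer between the producer task and the SSE response.
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
        async with inner_send_chan:
            try:
                iterator = await run_in_threadpool(fake_token_stream)
                async for token in iterate_in_threadpool(iterator):
                    await inner_send_chan.send(dict(data=token))
                    # Stop generating as soon as the client goes away.
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
                await inner_send_chan.send(dict(data="[DONE]"))
            except anyio.get_cancelled_exc_class():
                # Propagate cancellation so the publisher task shuts down cleanly.
                raise

    # EventSourceResponse consumes recv_chan while running event_publisher
    # as a background task via data_sender_callable.
    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )

One consequence of this design: the bounded buffer applies backpressure, so generation can pause when the client reads slowly, and the request.is_disconnected() check lets the worker stop early instead of producing tokens nobody will read.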
@@ -292,10 +310,12 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )


 class ChatCompletionRequestMessage(BaseModel):
@@ -349,36 +369,47 @@ class Config:
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e

         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
+        )
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
         )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+        return completion


 class ModelData(TypedDict):
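
A hypothetical client-side check (not part of the commit) for the streaming chat route: it assumes the server is running locally on the default port 8000 with a model already loaded, and prints each SSE data payload as it arrives.

import json

import requests

# Hypothetical client sketch; URL and request body are assumptions, not from the diff.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
    stream=True,
)

for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode("utf-8")
    if not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    # OpenAI-style chat chunks carry incremental text in choices[0]["delta"].
    delta = chunk["choices"][0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)
print()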
@@ -397,7 +428,7 @@ class ModelList(TypedDict):


 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
