@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict

 from typing_extensions import TypedDict, Literal

 import llama_cpp

-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
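Note (not part of the diff): the new imports above support a producer/consumer streaming pattern in which a blocking llama.cpp generator runs in a worker thread and its chunks are forwarded through an anyio memory channel into an SSE response. A minimal, self-contained sketch of that pattern, using a dummy generator and a hypothetical route name:

# Sketch only: `slow_chunks` and the /demo/stream route are illustrative stand-ins.
import json
from functools import partial
from typing import Iterator

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

app = FastAPI()

def slow_chunks() -> Iterator[dict]:
    # Stand-in for a blocking llama_cpp.Llama call that yields completion chunks.
    for i in range(3):
        yield {"token": i}

@app.get("/demo/stream")  # hypothetical route, for illustration only
async def demo_stream():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
        async with inner_send_chan:
            # Drive the blocking producer off the event loop and forward each chunk.
            iterator = await run_in_threadpool(slow_chunks)
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
            await inner_send_chan.send(dict(data="[DONE]"))

    # sse-starlette drains recv_chan while event_publisher fills send_chan.
    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )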
@@ -241,35 +245,49 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
-
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e

-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion


 class CreateEmbeddingRequest(BaseModel):
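Note (not part of the diff): a minimal client-side sketch of consuming the streaming /v1/completions endpoint rewritten above. The base URL, request fields, and chunk layout follow the OpenAI-style completions API and are assumptions here, not taken from the diff:

# Sketch only: server address and payload values are placeholders.
import json
import requests  # assumes the requests package is available

resp = requests.post(
    "http://localhost:8000/v1/completions",  # hypothetical local server address
    json={"prompt": "Hello", "stream": True, "max_tokens": 16},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw or not raw.startswith(b"data: "):
        continue  # skip keep-alives and non-data SSE fields
    payload = raw[len(b"data: "):].decode("utf-8")
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["text"], end="", flush=True)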
@@ -292,10 +310,12 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )


 class ChatCompletionRequestMessage(BaseModel):
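Note (not part of the diff): for the non-streaming paths the change is simply to push the blocking llama.cpp call onto Starlette's thread pool so the event loop stays free for other requests. A standalone illustration with a stand-in function and a hypothetical route:

# Sketch only: `slow_embedding` stands in for llama.create_embedding.
import time
from fastapi import FastAPI
from starlette.concurrency import run_in_threadpool

app = FastAPI()

def slow_embedding(text: str) -> dict:
    time.sleep(1.0)  # simulate a blocking, CPU-bound llama.cpp call
    return {"object": "embedding", "input": text}

@app.post("/demo/embedding")  # hypothetical route, for illustration only
async def demo_embedding(text: str):
    # Runs slow_embedding in a worker thread; concurrent requests are not blocked.
    return await run_in_threadpool(slow_embedding, text)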
@@ -349,36 +369,47 @@ class Config:
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e

         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
+        )
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
         )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+        return completion


 class ModelData(TypedDict):
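Note (not part of the diff): both event_publisher coroutines share the same disconnect handling: a client disconnect is promoted to an anyio cancellation so the producer stops, and the cleanup is briefly shielded so a closing event can still be flushed. A condensed sketch of just that control flow, with placeholder payloads:

# Sketch only: the loop body stands in for forwarding real completion chunks.
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import Request

async def publish_with_disconnect_check(
    request: Request, inner_send_chan: MemoryObjectSendStream
):
    async with inner_send_chan:
        try:
            for i in range(100):
                await inner_send_chan.send(dict(data=str(i)))
                if await request.is_disconnected():
                    # Turn a silent client disconnect into a cancellation so the
                    # publisher stops producing instead of filling the channel.
                    raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class():
            # Shield cleanup from further cancellation for up to one second so
            # the closing event can still reach the client if it is listening.
            with anyio.move_on_after(1, shield=True):
                await inner_send_chan.send(dict(closing=True))
            raise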
@@ -397,7 +428,7 @@ class ModelList(TypedDict):


 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList: