@@ -2,6 +2,7 @@
 
 import os
 import json
+import contextlib
 
 from threading import Lock
 from functools import partial
@@ -156,6 +157,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream,
     iterator: Iterator,
+    on_complete=None,
 ):
     async with inner_send_chan:
         try:
@@ -175,6 +177,9 @@ async def get_event_publisher(
             with anyio.move_on_after(1, shield=True):
                 print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
+        finally:
+            if on_complete:
+                on_complete()
 
 
 def _logit_bias_tokens_to_input_ids(
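
The new `on_complete` hook runs in a `finally` block, so it fires whether the stream ends normally or the client disconnects and the cancellation is re-raised. A minimal standalone sketch of that control flow (toy `publish` function, not the server code):

```python
import anyio


async def publish(chunks, on_complete=None):
    # Shaped like get_event_publisher: stream chunks, and always fire the
    # completion callback, even if the consumer cancels the task.
    try:
        for _ in chunks:
            await anyio.sleep(0)  # checkpoint, standing in for the awaited send
    finally:
        if on_complete:
            on_complete()


async def main():
    done = []
    # Simulate a client disconnect by cancelling the publisher mid-stream.
    with anyio.move_on_after(0.01):
        await publish(iter(range(10**8)), on_complete=lambda: done.append(True))
    assert done == [True]  # cleanup ran despite the cancellation


anyio.run(main)
```
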
@@ -258,8 +263,11 @@ async def authenticate(
 async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
 ) -> llama_cpp.Completion:
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
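
`get_llama_proxy` is a generator (`yield`) dependency; wrapping it with `contextlib.contextmanager` turns it into an ordinary context manager, and `ExitStack.enter_context` keeps it open until `exit_stack.close()` is called once streaming has finished, rather than letting FastAPI tear it down as soon as the handler returns. A minimal sketch of the pattern, using a hypothetical `acquire_resource` generator in place of `get_llama_proxy`:

```python
import contextlib


def acquire_resource():
    # Toy generator-style dependency, shaped like a FastAPI `yield` dependency:
    # hand out the resource, then release it in the teardown step.
    resource = {"open": True}
    try:
        yield resource
    finally:
        resource["open"] = False  # teardown FastAPI would normally run


exit_stack = contextlib.ExitStack()
# contextmanager() turns the generator function into a context manager;
# enter_context() enters it and parks its teardown on the stack.
resource = exit_stack.enter_context(contextlib.contextmanager(acquire_resource)())
assert resource["open"]

# ... the streaming response would be produced here, resource held open ...

exit_stack.close()  # runs the generator's finally block
assert not resource["open"]
```
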
@@ -312,6 +320,7 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
+            exit_stack.close()
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
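
Placing `exit_stack.close()` after the final `yield from` means it only runs if the client drains the whole stream; an abandoned generator never reaches that line, which is why `on_complete=exit_stack.close` is also passed to the event publisher below. A small illustration with toy names:

```python
def stream(on_last):
    yield "first"
    yield "second"
    on_last()  # runs only if the consumer drains the generator


seen = []
list(stream(lambda: seen.append("done")))  # fully consumed: trailing call runs
assert seen == ["done"]

gen = stream(lambda: seen.append("done"))
next(gen)    # consumer stops early, e.g. the client disconnected
gen.close()  # the statement after the last yield never executes
assert seen == ["done"]
```
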
@@ -321,6 +330,7 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
+                on_complete=exit_stack.close,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
@@ -449,8 +459,15 @@ async def create_chat_completion(
             },
         }
     ),
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
 ) -> llama_cpp.ChatCompletion:
+    # This is a workaround for an issue in FastAPI dependencies
+    # where the dependency is cleaned up before a StreamingResponse
+    # is complete.
+    # https://github.com/tiangolo/fastapi/issues/11143
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
     exclude = {
         "n",
         "logit_bias_type",
@@ -491,6 +508,7 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
+            exit_stack.close()
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
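
Both handlers now acquire the proxy through `run_in_threadpool`, presumably because entering `get_llama_proxy` can block while waiting for the model to become available; offloading it keeps the event loop responsive. A sketch of that idiom with a hypothetical `blocking_acquire`, using Starlette's `run_in_threadpool` (which FastAPI re-exports):

```python
import asyncio
import time

from starlette.concurrency import run_in_threadpool


def blocking_acquire():
    # Hypothetical stand-in for entering the llama_proxy context,
    # which may block until the model is free.
    time.sleep(0.1)
    return "proxy"


async def handler():
    # Run the blocking call in a worker thread so the event loop can
    # keep serving other requests in the meantime.
    return await run_in_threadpool(blocking_acquire)


print(asyncio.run(handler()))  # -> proxy
```
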
@@ -500,11 +518,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
+                on_complete=exit_stack.close,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
+        exit_stack.close()
         return iterator_or_completion
 
 
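
With this wiring `exit_stack.close()` may run more than once on a fully consumed stream (once at the end of `iterator()` and again via `on_complete` in the publisher's `finally`). That is harmless because `ExitStack` pops its callbacks as it unwinds, so a second `close()` is a no-op:

```python
import contextlib

calls = []
stack = contextlib.ExitStack()
stack.callback(lambda: calls.append("released"))

stack.close()
stack.close()  # the stack is already empty, nothing runs twice
assert calls == ["released"]
```
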