Commit 296304b

fix(server): Fix bug in FastAPI streaming response where the dependency was released before the request completed, causing a SEGFAULT

1 parent dc20e8c · commit 296304b

File tree: 1 file changed (+22, -2 lines)


llama_cpp/server/app.py

Lines changed: 22 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import os
 import json
+import contextlib
 
 from threading import Lock
 from functools import partial
@@ -156,6 +157,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream,
     iterator: Iterator,
+    on_complete=None,
 ):
     async with inner_send_chan:
         try:
@@ -175,6 +177,9 @@
             with anyio.move_on_after(1, shield=True):
                 print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
+        finally:
+            if on_complete:
+                on_complete()
 
 
 def _logit_bias_tokens_to_input_ids(
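The added finally block is what guarantees the cleanup callback fires whether the stream finishes normally or the client disconnects and the task is cancelled. A minimal sketch of that pattern, using an illustrative event_publisher coroutine rather than the real get_event_publisher:

    import anyio

    async def event_publisher(chunks, on_complete=None):
        # Simplified model of the publisher: forward each chunk, then always
        # invoke the completion callback, even if the consumer cancels us.
        try:
            for chunk in chunks:
                print("send", chunk)
                await anyio.sleep(0)
        finally:
            if on_complete:
                on_complete()

    anyio.run(event_publisher, ["a", "b"], lambda: print("resources released"))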
@@ -258,8 +263,11 @@ async def authenticate(
 async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
 ) -> llama_cpp.Completion:
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
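The lines replacing the Depends(get_llama_proxy) parameter enter the generator dependency by hand: contextlib.contextmanager() turns the dependency function into a context manager factory, and the ExitStack keeps the entered context open until close() is called. A minimal sketch of the idea, with a placeholder dependency standing in for the real LlamaProxy (the handler in the diff additionally wraps the enter step in run_in_threadpool, since entering the dependency may block the event loop):

    import contextlib

    def get_llama_proxy():
        # Illustrative generator dependency; a plain object stands in for LlamaProxy.
        proxy = object()
        try:
            yield proxy
        finally:
            print("dependency cleanup ran")

    # Enter the dependency manually so the endpoint decides when cleanup runs,
    # instead of FastAPI releasing it as soon as the handler returns.
    exit_stack = contextlib.ExitStack()
    llama_proxy = exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())

    # ... use llama_proxy while building/streaming the response ...

    exit_stack.close()  # prints "dependency cleanup ran"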
@@ -312,6 +320,7 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
+            exit_stack.close()
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -321,6 +330,7 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
+                on_complete=exit_stack.close,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
@@ -449,8 +459,15 @@ async def create_chat_completion(
             },
         }
     ),
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
 ) -> llama_cpp.ChatCompletion:
+    # This is a workaround for an issue in FastAPI dependencies
+    # where the dependency is cleaned up before a StreamingResponse
+    # is complete.
+    # https://github.com/tiangolo/fastapi/issues/11143
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
     exclude = {
         "n",
         "logit_bias_type",
@@ -491,6 +508,7 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
+            exit_stack.close()
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -500,11 +518,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
+                on_complete=exit_stack.close,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
+        exit_stack.close()
         return iterator_or_completion
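Taken together, the endpoints now have two cleanup paths: the streaming branch defers exit_stack.close() to the iterator's normal completion and to the publisher's on_complete callback (which also covers disconnects and errors), while the non-streaming branch closes the stack right before returning. A rough sketch of that control flow, with illustrative names rather than the real endpoint signature:

    import contextlib

    def respond(iterator_or_completion, exit_stack: contextlib.ExitStack):
        if hasattr(iterator_or_completion, "__next__"):
            # Streaming: release the dependency when the stream ends normally,
            # and also hand exit_stack.close to the publisher for disconnects.
            def iterator():
                yield from iterator_or_completion
                exit_stack.close()
            return {"iterator": iterator(), "on_complete": exit_stack.close}
        # Non-streaming: the response is already complete, release immediately.
        exit_stack.close()
        return iterator_or_completion

Calling close() from both places is safe: ExitStack pops its callbacks on the first close, so a second call is a no-op.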