Commit 4daf77e

Format
1 parent 2920c4b commit 4daf77e

File tree

1 file changed: +22 −28 lines changed

llama_cpp/server/app.py

Lines changed: 22 additions & 28 deletions
@@ -144,10 +144,8 @@ class ErrorResponseFormatters:
 
     @staticmethod
     def context_length_exceeded(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match, # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str]  # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for context length exceeded error"""
 
@@ -184,10 +182,8 @@ def context_length_exceeded(
 
     @staticmethod
     def model_not_found(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str]  # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for model_not_found error"""
 
@@ -315,12 +311,7 @@ def create_app(settings: Optional[Settings] = None):
         settings = Settings()
 
     middleware = [
-        Middleware(
-            RawContextMiddleware,
-            plugins=(
-                plugins.RequestIdPlugin(),
-            )
-        )
+        Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
     ]
     app = FastAPI(
         middleware=middleware,
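
Note: the one-line Middleware(...) call above is purely a black reflow of the starlette-context wiring. As an illustrative sketch only (the /ping endpoint and the "X-Request-ID" key are assumptions, not part of this file), the same wiring in isolation looks like:

# Sketch of the middleware wiring shown in the hunk above; the endpoint is
# illustrative and not part of llama_cpp/server/app.py.
from fastapi import FastAPI
from starlette.middleware import Middleware
from starlette_context import context, plugins
from starlette_context.middleware import RawContextMiddleware

middleware = [
    Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
]
app = FastAPI(middleware=middleware)


@app.get("/ping")
async def ping():
    # RequestIdPlugin stores the request id in the per-request context;
    # "X-Request-ID" is assumed to be the plugin's default header key.
    return {"request_id": context.data.get("X-Request-ID")}
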
@@ -426,12 +417,13 @@ async def get_event_publisher(
     except anyio.get_cancelled_exc_class() as e:
         print("disconnected")
         with anyio.move_on_after(1, shield=True):
-            print(
-                f"Disconnected from client (via refresh/close) {request.client}"
-            )
+            print(f"Disconnected from client (via refresh/close) {request.client}")
             raise e
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, description="The maximum number of tokens to generate."
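
Note: model_field and max_tokens_field above are module-level pydantic Field defaults that the request models attach to their own attributes. A minimal sketch of that reuse pattern (ExampleCompletionRequest is a made-up name, not a model from this file):

# Sketch only: reusing module-level Field defaults across request models.
from typing import Optional

from pydantic import BaseModel, Field

model_field = Field(
    description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
    default=16, ge=1, description="The maximum number of tokens to generate."
)


class ExampleCompletionRequest(BaseModel):
    # Each request model points its own attribute at the shared Field metadata.
    model: Optional[str] = model_field
    max_tokens: int = max_tokens_field
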
@@ -625,9 +617,9 @@ async def create_completion(
         ]
     )
 
-    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
-        llama_cpp.CompletionChunk
-    ]] = await run_in_threadpool(llama, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
+    ] = await run_in_threadpool(llama, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
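
Note: the reflowed annotation above does not change behaviour; the blocking llama(**kwargs) call still runs in a worker thread via run_in_threadpool and can hand back either a finished completion or a chunk iterator. A stripped-down sketch of that pattern (blocking_generate is a hypothetical stand-in, not llama_cpp):

# Sketch of the run_in_threadpool + isinstance dispatch used above.
from typing import Dict, Iterator, Union

from starlette.concurrency import run_in_threadpool


def blocking_generate(stream: bool) -> Union[Dict[str, str], Iterator[Dict[str, str]]]:
    # Hypothetical stand-in for calling the Llama model with **kwargs.
    if stream:
        return iter([{"text": "hel"}, {"text": "lo"}])
    return {"text": "hello"}


async def handle(stream: bool):
    result: Union[
        Dict[str, str], Iterator[Dict[str, str]]
    ] = await run_in_threadpool(blocking_generate, stream)
    if isinstance(result, Iterator):
        # Streaming case: the real handler wraps this in an SSE response.
        return list(result)
    return result
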
@@ -641,12 +633,13 @@ def iterator() -> Iterator[llama_cpp.CompletionChunk]:
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial( # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
         )
     else:
         return iterator_or_completion
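
Note: the EventSourceResponse call above streams completion chunks over SSE; an anyio memory channel buffers up to 10 events and partial binds get_event_publisher (defined earlier in this file) as the sender coroutine. A simplified sketch of how such a publisher can pump an iterator into the channel (error handling and disconnect checks from the real file are omitted; publish and sse_response are illustrative names):

# Simplified sketch of an SSE publisher in the style of get_event_publisher.
import json
from functools import partial
from typing import Dict, Iterator

import anyio
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool


async def publish(inner_send_chan, iterator: Iterator[Dict[str, str]]):
    async with inner_send_chan:
        # Drain the blocking iterator in a worker thread and forward each
        # chunk to the response as an SSE "data" event.
        async for chunk in iterate_in_threadpool(iterator):
            await inner_send_chan.send(dict(data=json.dumps(chunk)))


def sse_response(iterator: Iterator[Dict[str, str]]) -> EventSourceResponse:
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            publish, inner_send_chan=send_chan, iterator=iterator
        ),
    )
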
@@ -762,9 +755,9 @@ async def create_chat_completion(
         ]
     )
 
-    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
-        llama_cpp.ChatCompletionChunk
-    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -778,12 +771,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial( # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
         )
     else:
         return iterator_or_completion
