@@ -144,10 +144,8 @@ class ErrorResponseFormatters:
144
144
145
145
@staticmethod
146
146
def context_length_exceeded (
147
- request : Union [
148
- "CreateCompletionRequest" , "CreateChatCompletionRequest"
149
- ],
150
- match , # type: Match[str] # type: ignore
147
+ request : Union ["CreateCompletionRequest" , "CreateChatCompletionRequest" ],
148
+ match , # type: Match[str] # type: ignore
151
149
) -> Tuple [int , ErrorResponse ]:
152
150
"""Formatter for context length exceeded error"""
153
151
@@ -184,10 +182,8 @@ def context_length_exceeded(
184
182
185
183
@staticmethod
186
184
def model_not_found (
187
- request : Union [
188
- "CreateCompletionRequest" , "CreateChatCompletionRequest"
189
- ],
190
- match # type: Match[str] # type: ignore
185
+ request : Union ["CreateCompletionRequest" , "CreateChatCompletionRequest" ],
186
+ match , # type: Match[str] # type: ignore
191
187
) -> Tuple [int , ErrorResponse ]:
192
188
"""Formatter for model_not_found error"""
193
189
@@ -315,12 +311,7 @@ def create_app(settings: Optional[Settings] = None):
315
311
settings = Settings ()
316
312
317
313
middleware = [
318
- Middleware (
319
- RawContextMiddleware ,
320
- plugins = (
321
- plugins .RequestIdPlugin (),
322
- )
323
- )
314
+ Middleware (RawContextMiddleware , plugins = (plugins .RequestIdPlugin (),))
324
315
]
325
316
app = FastAPI (
326
317
middleware = middleware ,
@@ -426,12 +417,13 @@ async def get_event_publisher(
426
417
except anyio .get_cancelled_exc_class () as e :
427
418
print ("disconnected" )
428
419
with anyio .move_on_after (1 , shield = True ):
429
- print (
430
- f"Disconnected from client (via refresh/close) { request .client } "
431
- )
420
+ print (f"Disconnected from client (via refresh/close) { request .client } " )
432
421
raise e
433
422
434
- model_field = Field (description = "The model to use for generating completions." , default = None )
423
+
424
+ model_field = Field (
425
+ description = "The model to use for generating completions." , default = None
426
+ )
435
427
436
428
max_tokens_field = Field (
437
429
default = 16 , ge = 1 , description = "The maximum number of tokens to generate."
@@ -625,9 +617,9 @@ async def create_completion(
625
617
]
626
618
)
627
619
628
- iterator_or_completion : Union [llama_cpp . Completion , Iterator [
629
- llama_cpp .CompletionChunk
630
- ]] = await run_in_threadpool (llama , ** kwargs )
620
+ iterator_or_completion : Union [
621
+ llama_cpp .Completion , Iterator [ llama_cpp . CompletionChunk ]
622
+ ] = await run_in_threadpool (llama , ** kwargs )
631
623
632
624
if isinstance (iterator_or_completion , Iterator ):
633
625
# EAFP: It's easier to ask for forgiveness than permission
@@ -641,12 +633,13 @@ def iterator() -> Iterator[llama_cpp.CompletionChunk]:
641
633
642
634
send_chan , recv_chan = anyio .create_memory_object_stream (10 )
643
635
return EventSourceResponse (
644
- recv_chan , data_sender_callable = partial ( # type: ignore
636
+ recv_chan ,
637
+ data_sender_callable = partial ( # type: ignore
645
638
get_event_publisher ,
646
639
request = request ,
647
640
inner_send_chan = send_chan ,
648
641
iterator = iterator (),
649
- )
642
+ ),
650
643
)
651
644
else :
652
645
return iterator_or_completion
@@ -762,9 +755,9 @@ async def create_chat_completion(
762
755
]
763
756
)
764
757
765
- iterator_or_completion : Union [llama_cpp . ChatCompletion , Iterator [
766
- llama_cpp .ChatCompletionChunk
767
- ]] = await run_in_threadpool (llama .create_chat_completion , ** kwargs )
758
+ iterator_or_completion : Union [
759
+ llama_cpp .ChatCompletion , Iterator [ llama_cpp . ChatCompletionChunk ]
760
+ ] = await run_in_threadpool (llama .create_chat_completion , ** kwargs )
768
761
769
762
if isinstance (iterator_or_completion , Iterator ):
770
763
# EAFP: It's easier to ask for forgiveness than permission
@@ -778,12 +771,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
778
771
779
772
send_chan , recv_chan = anyio .create_memory_object_stream (10 )
780
773
return EventSourceResponse (
781
- recv_chan , data_sender_callable = partial ( # type: ignore
774
+ recv_chan ,
775
+ data_sender_callable = partial ( # type: ignore
782
776
get_event_publisher ,
783
777
request = request ,
784
778
inner_send_chan = send_chan ,
785
779
iterator = iterator (),
786
- )
780
+ ),
787
781
)
788
782
else :
789
783
return iterator_or_completion
0 commit comments