 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
@@ -113,13 +115,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0

+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()

     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
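The hunk above sets up a producer/consumer pair: the asyncio loop that drives vLLM produces responses, and a dedicated thread drains them toward Triton's frontend so that sending never stalls token generation. A minimal, self-contained sketch of the same pattern, using only the standard library and illustrative names rather than the backend's:

    import queue
    import threading


    class ResponseWorker:
        """Drain ready responses on a dedicated thread so the producer never blocks."""

        def __init__(self):
            self._queue = queue.Queue()
            self._thread = threading.Thread(target=self._loop)
            self._thread.start()

        def _loop(self):
            while True:
                item = self._queue.get()
                if item is None:  # sentinel enqueued by close()
                    break
                send, payload = item
                send(payload)

        def submit(self, send, payload):
            # Called from the event-loop thread; never blocks on an unbounded queue.
            self._queue.put_nowait((send, payload))

        def close(self):
            self._queue.put(None)
            self._thread.join()


    worker = ResponseWorker()
    worker.submit(print, "response 1")
    worker.submit(print, "response 2")
    worker.close()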
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):

         return params_dict

+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_sender, response, response_flag = item
+            del item
+            try:
+                response_sender.send(response, response_flag)
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+                del response_sender
+                if self.ongoing_request_count == 0:
+                    gc.collect()
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
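The del item and del response_sender statements, together with gc.collect() once ongoing_request_count drops to zero, exist to release each response sender as soon as its last response has gone out, presumably because the sender holds per-request resources that are only freed when the Python object is destroyed. A small self-contained illustration (no Triton objects involved) of why an explicit gc.collect() can still be needed after del:

    import gc
    import weakref


    class Sender:  # stand-in for an object we want reclaimed promptly
        pass


    a, b = Sender(), Sender()
    a.peer, b.peer = b, a       # reference cycle: refcounting alone cannot free the pair
    probe = weakref.ref(a)

    del a, b                    # analogous to `del item` / `del response_sender`
    assert probe() is not None  # still alive: the cycle keeps the refcounts non-zero
    gc.collect()                # the cycle collector reclaims them
    assert probe() is None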
@@ -314,6 +343,7 @@ async def generate(self, request):
         """
         response_sender = request.get_response_sender()
         self.ongoing_request_count += 1
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -368,9 +398,11 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
                 if response_sender.is_cancelled():
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
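The switch from iterating llm_engine.generate(...) to awaiting llm_engine.add_request(...) relies on the engine handing back the per-request output stream directly: the coroutine is awaited once and returns an async iterable of partial outputs, which appears to be the contract of vLLM's AsyncLLMEngine.add_request in the version this change targets. A tiny runnable mock of the shape being relied on, with made-up names:

    import asyncio


    class MockEngine:
        async def add_request(self, request_id, prompt, sampling_params):
            async def stream():
                for text, finished in ((prompt + " ...", False), (prompt + " ... done", True)):
                    yield {"request_id": request_id, "text": text, "finished": finished}

            return stream()


    async def main():
        engine = MockEngine()
        response_iterator = await engine.add_request("req-0", "hello", None)
        async for output in response_iterator:
            print(output)


    asyncio.run(main())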
@@ -383,15 +415,12 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_sender, response, flags))
                     prev_outputs = output

                 last_output = output
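With the queue in place, the streaming branch no longer calls response_sender.send() from the event loop: it builds the response once, picks the flag, and hands the (response_sender, response, flags) tuple to the response thread. put_nowait is safe here because queue.Queue is unbounded by default, so the call can neither block the event loop nor raise queue.Full:

    import queue

    q = queue.Queue()        # maxsize=0 means unbounded
    for i in range(100_000):
        q.put_nowait(i)      # never blocks, never raises queue.Full
    print(q.qsize())         # 100000

Note also that when the FINAL flag is queued, decrement_ongoing_request_count is cleared, so ongoing_request_count is decremented exactly once, by response_loop, and only after the last response has actually been sent.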
@@ -403,7 +432,7 @@ async def generate(self, request):
                 )

         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
+            del response_sender
+            if self.ongoing_request_count == 0:
+                gc.collect()

     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
         """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
         self._shutdown_event.set()
-        if self._loop_thread is not None:
-            self._loop_thread.join()
-            self._loop_thread = None
+
+        # Shutdown the event thread.
+        if self._event_thread is not None:
+            self._event_thread.join()
+            self._event_thread = None
+
+        # Shutdown the response thread.
+        self._response_queue.put(None)
+        if self._response_thread is not None:
+            self._response_thread.join()
+            self._response_thread = None
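The teardown order matters here: the shutdown event is set and the event-loop thread is joined first, so by the time the None sentinel is enqueued nothing can add further responses; the sentinel therefore lands behind whatever is still waiting in the queue, and response_loop drains those responses before it exits and lets the final join return.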