 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List

@@ -115,13 +117,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0

+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()

     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
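Note on the hunk above: the added response thread is a plain producer/consumer hand-off, so the asyncio engine loop only enqueues responses and never blocks on sending them to the server frontend. A minimal, Triton-free sketch of that idea (slow_send and the token strings are illustrative stand-ins, not part of this diff):

import asyncio
import queue
import threading
import time


def slow_send(response):
    # Stands in for a potentially blocking response_sender.send() call.
    time.sleep(0.05)


response_queue = queue.Queue()


def response_loop():
    # Drain the queue until a None sentinel arrives.
    while (item := response_queue.get()) is not None:
        slow_send(item)


async def produce_tokens(n):
    for i in range(n):
        response_queue.put_nowait(f"token_{i}")  # never blocks the event loop
        await asyncio.sleep(0)  # yield so other requests can make progress


worker = threading.Thread(target=response_loop)
worker.start()
asyncio.run(produce_tokens(5))
response_queue.put(None)  # shut the worker down
worker.join()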
@@ -290,6 +298,27 @@ def get_sampling_params_dict(self, params_json):

         return params_dict

+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
+            try:
+                response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
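For orientation, the queue items consumed by response_loop() are the (response_state, response, flags) tuples produced by generate(): the worker caches response_sender.is_cancelled() into the shared response_state dict for the asyncio side to read, and releases the in-flight counter only when it sees the FINAL flag. A toy model of that bookkeeping, with ToySender and FINAL as illustrative stand-ins for Triton's response sender and pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:

import queue
import threading

FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL


class ToySender:
    # Illustrative stand-in for Triton's response sender.
    def send(self, response, flags):
        pass

    def is_cancelled(self):
        return False


ongoing = {"count": 1}
response_queue = queue.Queue()


def response_loop():
    while True:
        item = response_queue.get()
        if item is None:  # sentinel pushed at shutdown
            break
        state, response, flags = item
        state["response_sender"].send(response, flags)
        if not state["last_response_generated"]:
            # Cache the cancellation status for the producer coroutine to read.
            state["is_cancelled"] = state["response_sender"].is_cancelled()
        if flags == FINAL:
            ongoing["count"] -= 1  # request fully answered


worker = threading.Thread(target=response_loop)
worker.start()

state = {
    "response_sender": ToySender(),
    "is_cancelled": False,
    "last_response_generated": False,
}
response_queue.put((state, "partial", 0))
state["last_response_generated"] = True
response_queue.put((state, "final", FINAL))
response_queue.put(None)
worker.join()
print(ongoing["count"])  # 0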
@@ -330,7 +359,13 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -385,13 +420,31 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
-                if response_sender.is_cancelled():
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -400,15 +453,13 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        response_state["last_response_generated"] = True
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_state, response, flags))
                 prev_outputs = output

                 last_output = output
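The prev_outputs_lengths bookkeeping above exists so each streamed response carries only newly generated text; create_stream_response (not shown in this diff) slices each output by the length already sent. A rough sketch of that delta computation, with plain strings standing in for vLLM output objects:

# Plain strings stand in for vllm output.text at successive engine steps.
outputs_over_time = ["Hel", "Hello, wor", "Hello, world!"]

prev_length = 0
for text in outputs_over_time:
    delta = text[prev_length:]  # only the newly generated characters
    prev_length = len(text)
    print(repr(delta))
# Prints 'Hel', then 'lo, wor', then 'ld!'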
@@ -420,7 +471,7 @@ async def generate(self, request):
                 )

         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -433,7 +484,8 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1

     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -500,6 +552,20 @@ def finalize(self):
         """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
         self._shutdown_event.set()
-        if self._loop_thread is not None:
-            self._loop_thread.join()
-            self._loop_thread = None
+
+        # Shutdown the event thread.
+        if self._event_thread is not None:
+            self._event_thread.join()
+            self._event_thread = None
+
+        # Shutdown the response thread.
+        self._response_queue.put(None)
+        if self._response_thread is not None:
+            self._response_thread.join()
+            self._response_thread = None
+
+        # When using parallel tensors, the stub process may not shutdown due to
+        # unreleased references, so manually run the garbage collector once.
+        self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
+        gc.collect()
+        self.logger.log_info("[vllm] Garbage Collector on finalize... done")
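The ordering in finalize() matters: the event thread running the asyncio engine loop is stopped and joined first, then a None sentinel unblocks the response thread so it can be joined, and finally gc.collect() drops lingering references. A self-contained toy of that shutdown sequence (engine_thread and response_thread here are illustrative stand-ins, not the backend's real threads):

import gc
import queue
import threading

shutdown_event = threading.Event()
response_queue = queue.Queue()


def engine_thread():
    # The real code runs the asyncio engine loop until the shutdown event is set.
    shutdown_event.wait()


def response_thread():
    # The real code sends queued responses; None is the shutdown sentinel.
    while response_queue.get() is not None:
        pass


event_worker = threading.Thread(target=engine_thread)
response_worker = threading.Thread(target=response_thread)
event_worker.start()
response_worker.start()

shutdown_event.set()       # 1. let the event/engine thread wind down
event_worker.join()
response_queue.put(None)   # 2. sentinel unblocks the response thread
response_worker.join()
gc.collect()               # 3. release leftover references before the stub exits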