Commit 128abc3

Tabrizian and kthui authored
perf: Improve vLLM backend performance by using a separate thread for responses (#46)
Co-authored-by: Jacky <[email protected]>
1 parent 05c5a8b commit 128abc3
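
The change moves response sending off the asyncio engine loop: generate() now enqueues (response_sender, response, flags) tuples, and a dedicated thread drains the queue and calls send(), so vLLM can keep generating tokens while responses are being delivered to the server frontend. Below is a minimal, self-contained sketch of that producer/consumer pattern; the class and method names are illustrative only and are not part of model.py or the Triton API.

import queue
import threading

# Illustrative sketch of the queue + worker-thread pattern this commit adopts.
# ResponseDispatcher, dispatch, and shutdown are invented names for the sketch.

class ResponseDispatcher:
    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._response_loop)
        self._thread.start()

    def _response_loop(self):
        while True:
            item = self._queue.get()
            # A None item is the shutdown sentinel, mirroring finalize() below.
            if item is None:
                break
            send_fn, payload = item
            try:
                # In model.py this is response_sender.send(response, flags).
                send_fn(payload)
            except Exception as e:
                print(f"error while sending a response: {e}")

    def dispatch(self, send_fn, payload):
        # Producer side: enqueue and return immediately, like
        # self._response_queue.put_nowait(...) in generate().
        self._queue.put_nowait((send_fn, payload))

    def shutdown(self):
        # Same protocol as finalize(): push the sentinel, then join the thread.
        self._queue.put(None)
        self._thread.join()


if __name__ == "__main__":
    dispatcher = ResponseDispatcher()
    dispatcher.dispatch(print, "response handled off the caller's thread")
    dispatcher.shutdown()

The sentinel-based shutdown matches the finalize() change in the diff: a None item is pushed onto the queue and the response thread is joined after the engine thread.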

File tree: 1 file changed (+59, -18)

src/model.py

Lines changed: 59 additions & 18 deletions
@@ -25,8 +25,10 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
 
@@ -113,13 +115,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()
 
     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):
 
         return params_dict
 
+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_sender, response, response_flag = item
+            del item
+            try:
+                response_sender.send(response, response_flag)
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+                    del response_sender
+                    if self.ongoing_request_count == 0:
+                        gc.collect()
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
@@ -314,6 +343,7 @@ async def generate(self, request):
         """
         response_sender = request.get_response_sender()
         self.ongoing_request_count += 1
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -368,9 +398,11 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
 
-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
                 if response_sender.is_cancelled():
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
@@ -383,15 +415,12 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_sender, response, flags))
                     prev_outputs = output
 
                 last_output = output
@@ -403,7 +432,7 @@ async def generate(self, request):
                 )
 
         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
+            del response_sender
+            if self.ongoing_request_count == 0:
+                gc.collect()
 
     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
         """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
         self._shutdown_event.set()
-        if self._loop_thread is not None:
-            self._loop_thread.join()
-            self._loop_thread = None
+
+        # Shutdown the event thread.
+        if self._event_thread is not None:
+            self._event_thread.join()
+            self._event_thread = None
+
+        # Shutdown the response thread.
+        self._response_queue.put(None)
+        if self._response_thread is not None:
+            self._response_thread.join()
+            self._response_thread = None
