Commit c8bdb6e

Merge branch 'main' of github.com:triton-inference-server/vllm_backend into yinggeh-DLIS-7061-add-vllm-metrics

2 parents: d22fd03 + 843cbdd

src/model.py (85 additions, 19 deletions)

@@ -25,8 +25,10 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
 
@@ -115,13 +117,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()
 
     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
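
Taken in isolation, the threading layout this hunk sets up can be sketched as below. This is an illustrative, self-contained example rather than the backend's code: `Worker`, `produce`, and the use of `asyncio.new_event_loop()` / `run_forever` are stand-ins (the model drives its loop through its own `engine_loop`), but the shape is the same: one thread runs the asyncio event loop for the engine, and a second thread drains a `queue.Queue` so response delivery never blocks that loop.

```python
import asyncio
import queue
import threading


class Worker:
    """Illustrative stand-in: an asyncio loop thread plus a response thread."""

    def __init__(self):
        # Response thread: drains responses so the event loop never blocks on delivery.
        self._response_queue = queue.Queue()
        self._response_thread = threading.Thread(target=self._response_loop)
        self._response_thread.start()

        # Event-loop thread: runs engine coroutines in the background.
        self._loop = asyncio.new_event_loop()
        self._event_thread = threading.Thread(target=self._loop.run_forever)
        self._event_thread.start()

    def _response_loop(self):
        while True:
            item = self._response_queue.get()
            if item is None:  # shutdown sentinel
                break
            print("sending:", item)

    def submit(self, coro):
        # Schedule a coroutine on the loop thread from any other thread.
        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def shutdown(self):
        self._response_queue.put(None)
        self._response_thread.join()
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._event_thread.join()
        self._loop.close()


async def produce(worker):
    # Pretend to be the engine: push a few "responses" to the response thread.
    for i in range(3):
        await asyncio.sleep(0.01)
        worker._response_queue.put_nowait(f"token-{i}")


if __name__ == "__main__":
    w = Worker()
    w.submit(produce(w)).result()
    w.shutdown()
```
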
@@ -290,6 +298,27 @@ def get_sampling_params_dict(self, params_json):
 
         return params_dict
 
+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
+            try:
+                response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
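
The new `response_loop` is a single-consumer queue worker shut down by a `None` sentinel. Below is a minimal runnable sketch of the same shape; `FakeSender` and `FINAL_FLAG` are hypothetical stand-ins for Triton's response sender and `pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL`, used only so the example is self-contained.

```python
import queue
import threading

FINAL_FLAG = 1  # hypothetical stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL


class FakeSender:
    """Hypothetical stand-in for Triton's decoupled response sender."""

    def send(self, response, flags):
        print(f"sent {response!r} (flags={flags})")

    def is_cancelled(self):
        return False


class ResponseWorker:
    def __init__(self):
        self.ongoing_request_count = 1
        self.queue = queue.Queue()
        self.thread = threading.Thread(target=self.response_loop)
        self.thread.start()

    def response_loop(self):
        while True:
            item = self.queue.get()
            if item is None:  # shutdown sentinel
                break
            state, response, flags = item
            sender = state["response_sender"]
            try:
                sender.send(response, flags)
                # Keep polling for cancellation only while more responses are expected.
                if not state["last_response_generated"]:
                    state["is_cancelled"] = sender.is_cancelled()
            finally:
                # The FINAL flag, not the request itself, drives the counter decrement.
                if flags == FINAL_FLAG:
                    self.ongoing_request_count -= 1


worker = ResponseWorker()
state = {
    "response_sender": FakeSender(),
    "is_cancelled": False,
    "last_response_generated": False,
}
worker.queue.put_nowait((state, "partial output", 0))
state["last_response_generated"] = True
worker.queue.put_nowait((state, "final output", FINAL_FLAG))
worker.queue.put(None)
worker.thread.join()
assert worker.ongoing_request_count == 0
```
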
@@ -330,7 +359,13 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -385,13 +420,31 @@
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
 
-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
-                if response_sender.is_cancelled():
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
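
For streaming requests the loop above no longer calls `response_sender.is_cancelled()` inline; it reads the flag that the response thread last wrote into `response_state["is_cancelled"]`. A small hedged sketch of that cooperative, per-iteration cancellation check (the `token_stream` generator is illustrative and not vLLM's API):

```python
import asyncio


async def token_stream(n, delay=0.01):
    # Stand-in for the per-request async iterator returned by the engine.
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"token-{i}"


async def generate(state):
    results = []
    async for token in token_stream(100):
        # Cancellation is observed between outputs, mirroring the per-iteration
        # read of response_state["is_cancelled"] in the commit.
        if state["is_cancelled"]:
            results.append("<cancelled>")
            break
        results.append(token)
    return results


async def main():
    state = {"is_cancelled": False}
    task = asyncio.create_task(generate(state))
    await asyncio.sleep(0.05)
    state["is_cancelled"] = True  # e.g. flipped by the response thread
    print(await task)


asyncio.run(main())
```
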
@@ -400,15 +453,13 @@
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        response_state["last_response_generated"] = True
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_state, response, flags))
                 prev_outputs = output
 
             last_output = output
@@ -420,7 +471,7 @@
                 )
 
         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -433,7 +484,8 @@
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
 
     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -500,6 +552,20 @@ def finalize(self):
         """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
         self._shutdown_event.set()
-        if self._loop_thread is not None:
-            self._loop_thread.join()
-            self._loop_thread = None
+
+        # Shutdown the event thread.
+        if self._event_thread is not None:
+            self._event_thread.join()
+            self._event_thread = None
+
+        # Shutdown the response thread.
+        self._response_queue.put(None)
+        if self._response_thread is not None:
+            self._response_thread.join()
+            self._response_thread = None
+
+        # When using parallel tensors, the stub process may not shutdown due to
+        # unreleased references, so manually run the garbage collector once.
+        self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
+        gc.collect()
+        self.logger.log_info("[vllm] Garbage Collector on finalize... done")
