
Commit 843cbdd

kthui and Tabrizian authored
perf: Check for cancellation on response thread (#54)
Co-authored-by: Iman Tabrizian <[email protected]>
1 parent a345a1d commit 843cbdd

File tree

1 file changed: +29 −3 lines changed

src/model.py

Lines changed: 29 additions & 3 deletions
@@ -287,9 +287,13 @@ def response_loop(self):
             # To signal shutdown a None item will be added to the queue.
             if item is None:
                 break
-            response_sender, response, response_flag = item
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
             try:
                 response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
             except Exception as e:
                 self.logger.log_error(
                     f"An error occurred while sending a response: {e}"
@@ -338,6 +342,11 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
         decrement_ongoing_request_count = True
         try:
@@ -399,10 +408,26 @@ async def generate(self, request):
             )

             async for output in response_iterator:
-                if response_sender.is_cancelled():
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -414,9 +439,10 @@ async def generate(self, request):
                     response = self.create_stream_response(output, prev_outputs_lengths)
                     flags = 0
                     if output.finished:
+                        response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
-                    self._response_queue.put_nowait((response_sender, response, flags))
+                    self._response_queue.put_nowait((response_state, response, flags))
                     prev_outputs = output

                 last_output = output
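
For readers outside this codebase, the sketch below distills the pattern the commit introduces: the response thread that drains the queue also probes is_cancelled() and caches the result in a shared per-request state dict, so the streaming generation loop only reads the cached flag instead of calling the sender on every iteration. This is a minimal, standalone sketch; FakeResponseSender, FINAL_FLAG, and the queue wiring are illustrative stand-ins, not the Triton pb_utils API used in the actual diff.

# Minimal sketch of cancellation checking on the response thread (assumed names,
# not the Triton backend itself).
import queue
import threading


class FakeResponseSender:
    """Stand-in for a response sender that exposes an is_cancelled() probe."""

    def __init__(self):
        self._cancelled = False

    def cancel(self):
        self._cancelled = True

    def is_cancelled(self):
        return self._cancelled

    def send(self, response, flags):
        print(f"sent: {response!r} (flags={flags})")


FINAL_FLAG = 1  # stand-in for TRITONSERVER_RESPONSE_COMPLETE_FINAL
response_queue = queue.Queue()


def response_loop():
    """Consumer thread: sends responses and caches the cancellation status."""
    while True:
        item = response_queue.get()
        if item is None:  # shutdown signal
            break
        state, response, flags = item
        state["response_sender"].send(response, flags)
        # Probe cancellation here, off the generation loop's hot path.
        if not state["last_response_generated"]:
            state["is_cancelled"] = state["response_sender"].is_cancelled()


def generate(outputs, sender):
    """Producer: reads the cached flag instead of probing the sender itself."""
    state = {
        "response_sender": sender,
        "is_cancelled": False,
        "last_response_generated": False,
    }
    for i, output in enumerate(outputs):
        if state["is_cancelled"]:
            print("cancelled; stopping generation")
            break
        last = i == len(outputs) - 1
        if last:
            state["last_response_generated"] = True
        response_queue.put((state, output, FINAL_FLAG if last else 0))


worker = threading.Thread(target=response_loop)
worker.start()
generate(["tok1", "tok2", "tok3"], FakeResponseSender())
response_queue.put(None)  # signal shutdown
worker.join()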
