@@ -287,9 +287,13 @@ def response_loop(self):
             # To signal shutdown a None item will be added to the queue.
             if item is None:
                 break
-            response_sender, response, response_flag = item
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
             try:
                 response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
             except Exception as e:
                 self.logger.log_error(
                     f"An error occurred while sending a response: {e}"
@@ -338,6 +342,11 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
         decrement_ongoing_request_count = True
         try:
@@ -399,10 +408,26 @@ async def generate(self, request):
             )

             async for output in response_iterator:
-                if response_sender.is_cancelled():
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -414,9 +439,10 @@ async def generate(self, request):
                     response = self.create_stream_response(output, prev_outputs_lengths)
                     flags = 0
                     if output.finished:
+                        response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
-                    self._response_queue.put_nowait((response_sender, response, flags))
+                    self._response_queue.put_nowait((response_state, response, flags))
                 prev_outputs = output

             last_output = output
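(Taken together, a minimal end-to-end sketch of the producer side under the same state-dict protocol. Here fake_engine and the FINAL constant are stand-ins for the vLLM engine's response iterator and pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL; the point is that the stream path reads state["is_cancelled"] rather than polling the sender from the event loop.)

import asyncio
import queue

FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL

async def fake_engine(n):
    # Yields dicts in place of vLLM RequestOutput objects.
    for i in range(n):
        yield {"text": f"token-{i}", "finished": i == n - 1}

async def generate(q, state):
    async for output in fake_engine(5):
        # The cancellation flag is refreshed by response_loop after
        # each send; the producer just reads it here.
        if state["is_cancelled"]:
            state["last_response_generated"] = True
            q.put_nowait((state, "Request was cancelled", FINAL))
            break
        flags = 0
        if output["finished"]:
            state["last_response_generated"] = True
            flags = FINAL
        q.put_nowait((state, output["text"], flags))

q = queue.Queue()
state = {"response_sender": None, "is_cancelled": False,
         "last_response_generated": False}
asyncio.run(generate(q, state))
while not q.empty():
    print(q.get_nowait())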