Support models that don't split stream chunks in tokens #8235


Merged · 2 commits · May 17, 2025
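In short, the listener previously assumed each stream chunk was roughly one token. Models such as Gemini can pack many tokens, sometimes the entire response, into a single chunk, so a field marker and its value may arrive together. A rough sketch of the two chunking styles (the field name and chunk contents are invented for illustration):

# Token-by-token chunking: the ChatAdapter marker arrives in pieces.
token_split_chunks = ["[[", " ##", " answer", " ##", " ]]", "\n", "Par", "is"]

# Multi-token chunking (Gemini-style): marker and value share one chunk.
multi_token_chunks = ["[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"]

# A listener that buffers marker tokens chunk by chunk handles the first
# stream; the fallback added in this PR is what handles the second.
print("[[ ## answer ## ]]" in "".join(token_split_chunks))  # True
print("[[ ## answer ## ]]" in multi_token_chunks[0])        # True, inside a single chunk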
57 changes: 41 additions & 16 deletions dspy/streaming/streaming_listener.py
@@ -36,12 +36,18 @@ def __init__(self, signature_field_name: str, predict: Any = None, predict_name:
         self.stream_end = False
         self.cache_hit = False
 
-        self.json_adapter_start_identifier = f'{{"{self.signature_field_name}":"'
-        self.json_adapter_end_identifier = re.compile(r"\w*\",\w*")
+        self.json_adapter_start_identifier = f'"{self.signature_field_name}":'
+        self.json_adapter_end_identifier = re.compile(r"\w*\"(,|\s*})")
 
         self.chat_adapter_start_identifier = f"[[ ## {self.signature_field_name} ## ]]"
         self.chat_adapter_end_identifier = re.compile(r"\[\[ ## (\w+) ## \]\]")
 
+    def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> bool:
+        for i in range(len(concat_message)):
+            if start_identifier.startswith(concat_message[len(concat_message) - i - 1 :]):
+                return True
+        return False
+
     def receive(self, chunk: ModelResponseStream):
         if isinstance(settings.adapter, JSONAdapter):
             start_identifier = self.json_adapter_start_identifier
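The new helper checks whether the text buffered so far could still grow into the start identifier, i.e., whether some non-empty suffix of the buffer is a prefix of the identifier. A standalone sketch of the same logic (function name and sample strings are illustrative):

def ends_with_prefix_of(buffer: str, identifier: str) -> bool:
    # True when some non-empty suffix of `buffer` is a prefix of `identifier`,
    # meaning the identifier may still be arriving in later chunks.
    for i in range(len(buffer)):
        if identifier.startswith(buffer[len(buffer) - i - 1:]):
            return True
    return False

identifier = '"answer":'
print(ends_with_prefix_of('...so far "ans', identifier))  # True: '"ans' begins the identifier
print(ends_with_prefix_of('...so far xyz', identifier))   # False: no suffix matches, so the queue is reset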
@@ -52,7 +58,7 @@ def receive(self, chunk: ModelResponseStream):
             start_identifier = self.chat_adapter_start_identifier
             end_identifier = self.chat_adapter_end_identifier
 
-            start_indicator = "[["
+            start_indicator = "["
         else:
             raise ValueError(
                 f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
@@ -71,13 +77,18 @@ def receive(self, chunk: ModelResponseStream):

         if chunk_message and start_identifier in chunk_message:
             # If the cache is hit, the chunk_message could be the full response. When it happens we can
-            # directly end the stream listening.
-            self.cache_hit = True
-            self.stream_start = True
-            self.stream_end = True
-            return
+            # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
+            # tokens, so it's possible that the response has only one chunk; we also fall back to this logic.
+            message_after_start_identifier = chunk_message[
+                chunk_message.find(start_identifier) + len(start_identifier) :
+            ]
+            if re.search(end_identifier, message_after_start_identifier):
+                self.cache_hit = True
+                self.stream_start = True
+                self.stream_end = True
+                return
 
-        if len(self.field_start_queue) == 0 and start_indicator in chunk_message:
+        if len(self.field_start_queue) == 0 and not self.stream_start and start_indicator in chunk_message:
             # We look for the pattern of start_identifier, i.e., "[[ ## {self.signature_field_name} ## ]]" for
             # ChatAdapter to identify the start of the stream of our target field. Once the start_indicator, i.e., "[["
             # for ChatAdapter, is found, we start checking the next tokens
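The fallback now also covers non-cache cases: if the text after the start identifier already matches the end-identifier pattern, the whole field arrived in one chunk. Roughly (identifiers rebuilt from the constructor above; the payload is invented):

import re

field = "answer"
start_identifier = f'"{field}":'
end_identifier = re.compile(r"\w*\"(,|\s*})")

chunk_message = '{"answer": "Paris", "confidence": "high"}'  # one chunk, full response
after_start = chunk_message[chunk_message.find(start_identifier) + len(start_identifier):]
print(bool(end_identifier.search(after_start)))  # True: mark the stream started and ended at once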
@@ -89,18 +100,28 @@ def receive(self, chunk: ModelResponseStream):
             # tokens no longer match our expected identifier.
             self.field_start_queue.append(chunk_message)
             concat_message = "".join(self.field_start_queue).strip()
-            start_token_index = concat_message.find(start_indicator)
-            concat_message = concat_message[start_token_index:]
-            if start_identifier == concat_message:
+
+            if start_identifier in concat_message:
                 # We have a full identifier, we can start the stream.
                 self.stream_start = True
-            elif start_identifier.startswith(concat_message):
-                # The concatenated tokens still match our expected identifier, we keep listening.
                 self.field_start_queue = []
+                # Keep the part after the start_identifier from the concat_message; we need to write it to the buffer.
+                value_start_index = concat_message.find(start_identifier) + len(start_identifier)
+                chunk_message = concat_message[value_start_index:].lstrip()
+                if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
+                    # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
+                    # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
+                    chunk_message = chunk_message[1:]
+
+            elif self._buffered_message_end_with_start_identifier(concat_message, start_identifier):
+                # The buffered message ends with part of the start_identifier, so we keep listening.
+                return
             else:
                 # Doesn't match the expected identifier, reset the queue.
                 self.field_start_queue = []
-        elif self.stream_start:
+                return
+
+        if self.stream_start:
             # The stream is started, we keep returning the token until we see the start of the next field.
             token = None
             self.field_end_queue.put(chunk_message)
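When the identifier and the beginning of the value share a chunk, the listener now recovers the value tail from the concatenated buffer instead of dropping it. A trimmed-down sketch of that slicing (sample strings invented):

field = "answer"
start_identifier = f'"{field}":'

concat_message = 'leading text {"answer": "Par'  # buffered chunks joined together
value_start_index = concat_message.find(start_identifier) + len(start_identifier)
chunk_message = concat_message[value_start_index:].lstrip()  # '"Par'
if chunk_message.startswith('"'):
    chunk_message = chunk_message[1:]  # JSONAdapter path: drop the value's opening quote
print(chunk_message)  # Par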
@@ -130,7 +151,11 @@ def flush(self) -> str:
         last_tokens = "".join(self.field_end_queue.queue)
         self.field_end_queue = Queue()
         if isinstance(settings.adapter, JSONAdapter):
-            boundary_index = last_tokens.find('",')
+            match = re.search(r'",|"\s*}', last_tokens)
+            if match:
+                boundary_index = match.start()
+            else:
+                boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
         elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
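flush() gets the same relaxation on the JSON path: the trailing buffer may end with "} (end of the object) rather than ", (another field follows), and when no boundary is found the whole buffer is kept rather than blindly sliced. For instance (buffer contents invented):

import re

last_tokens = 'is" }'  # tail of the value, then the object closes
match = re.search(r'",|"\s*}', last_tokens)
boundary_index = match.start() if match else len(last_tokens)
print(last_tokens[:boundary_index])  # 'is'; the old find('",') returned -1 here, which kept a stray quote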