@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
71
71
})
72
72
content = ""
73
73
last_cmpl_id = None
74
- for data in res :
74
+ for i , data in enumerate ( res ) :
75
75
choice = data ["choices" ][0 ]
76
+ if i == 0 :
77
+ # Check first role message for stream=True
78
+ assert choice ["delta" ]["content" ] == ""
79
+ assert choice ["delta" ]["role" ] == "assistant"
80
+ else :
81
+ assert "role" not in choice ["delta" ]
76
82
assert data ["system_fingerprint" ].startswith ("b" )
77
83
assert "gpt-3.5" in data ["model" ] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
78
84
if last_cmpl_id is None :
@@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
242
248
"stream" : True ,
243
249
"timings_per_token" : True ,
244
250
})
245
- for data in res :
246
- assert "timings" in data
247
- assert "prompt_per_second" in data ["timings" ]
248
- assert "predicted_per_second" in data ["timings" ]
249
- assert "predicted_n" in data ["timings" ]
250
- assert data ["timings" ]["predicted_n" ] <= 10
251
+ for i , data in enumerate (res ):
252
+ if i == 0 :
253
+ # Check first role message for stream=True
254
+ assert data ["choices" ][0 ]["delta" ]["content" ] == ""
255
+ assert data ["choices" ][0 ]["delta" ]["role" ] == "assistant"
256
+ else :
257
+ assert "role" not in data ["choices" ][0 ]["delta" ]
258
+ assert "timings" in data
259
+ assert "prompt_per_second" in data ["timings" ]
260
+ assert "predicted_per_second" in data ["timings" ]
261
+ assert "predicted_n" in data ["timings" ]
262
+ assert data ["timings" ]["predicted_n" ] <= 10
251
263
252
264
253
265
def test_logprobs ():
@@ -295,17 +307,23 @@ def test_logprobs_stream():
295
307
)
296
308
output_text = ''
297
309
aggregated_text = ''
298
- for data in res :
310
+ for i , data in enumerate ( res ) :
299
311
choice = data .choices [0 ]
300
- if choice .finish_reason is None :
301
- if choice .delta .content :
302
- output_text += choice .delta .content
303
- assert choice .logprobs is not None
304
- assert choice .logprobs .content is not None
305
- for token in choice .logprobs .content :
306
- aggregated_text += token .token
307
- assert token .logprob <= 0.0
308
- assert token .bytes is not None
309
- assert token .top_logprobs is not None
310
- assert len (token .top_logprobs ) > 0
312
+ if i == 0 :
313
+ # Check first role message for stream=True
314
+ assert choice .delta .content == ""
315
+ assert choice .delta .role == "assistant"
316
+ else :
317
+ assert choice .delta .role is None
318
+ if choice .finish_reason is None :
319
+ if choice .delta .content :
320
+ output_text += choice .delta .content
321
+ assert choice .logprobs is not None
322
+ assert choice .logprobs .content is not None
323
+ for token in choice .logprobs .content :
324
+ aggregated_text += token .token
325
+ assert token .logprob <= 0.0
326
+ assert token .bytes is not None
327
+ assert token .top_logprobs is not None
328
+ assert len (token .top_logprobs ) > 0
311
329
assert aggregated_text == output_text
0 commit comments