Skip to content

Commit 410a87b

Browse files
committed
Make StreamFun work with ollamaChat and azureChat
ollama sends a (correct) header of Content-Type: application/ndjson, so we need to switch from StreamConsumer to BinaryConsumer, or the response will never get dispatched. ollama uses a slightly different JSON format, need to branch based on presence of fields. Azure sends JSON broken up in several packets. We remember incomplete lines in the object and reassemble when the next response comes in. Azure sends "choices":[]. Handle by ignoring that line.
1 parent 08a8549 commit 410a87b

File tree

3 files changed

+81
-29
lines changed

3 files changed

+81
-29
lines changed

+llms/+stream/responseStreamer.m

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
classdef responseStreamer < matlab.net.http.io.StringConsumer
1+
classdef responseStreamer < matlab.net.http.io.BinaryConsumer
22
%responseStreamer Responsible for obtaining the streaming results from the
33
%API
44

@@ -7,6 +7,7 @@
77
properties
88
ResponseText
99
StreamFun
10+
Incomplete = ""
1011
end
1112

1213
methods
@@ -20,17 +21,19 @@
2021
if this.Response.StatusCode ~= matlab.net.http.StatusCode.OK
2122
length = 0;
2223
else
23-
length = this.start@matlab.net.http.io.StringConsumer;
24+
length = this.start@matlab.net.http.io.BinaryConsumer;
2425
end
2526
end
2627
end
2728

2829
methods
2930
function [len,stop] = putData(this, data)
30-
[len,stop] = this.putData@matlab.net.http.io.StringConsumer(data);
31+
[len,stop] = this.putData@matlab.net.http.io.BinaryConsumer(data);
3132

3233
% Extract out the response text from the message
3334
str = native2unicode(data','UTF-8');
35+
str = this.Incomplete + string(str);
36+
this.Incomplete = "";
3437
str = split(str,newline);
3538
str = str(strlength(str)>0);
3639
str = erase(str,"data: ");
@@ -43,35 +46,50 @@
4346
try
4447
json = jsondecode(str{i});
4548
catch ME
49+
if i == length(str)
50+
this.Incomplete = str{i};
51+
return;
52+
end
4653
errID = 'llms:stream:responseStreamer:InvalidInput';
4754
msg = "Input does not have the expected json format. " + str{i};
4855
ME = MException(errID,msg);
4956
throw(ME)
5057
end
51-
if ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"])
52-
stop = true;
53-
return
54-
else
55-
if isfield(json.choices.delta,"tool_calls")
56-
if isfield(json.choices.delta.tool_calls,"id")
57-
id = json.choices.delta.tool_calls.id;
58-
type = json.choices.delta.tool_calls.type;
59-
fcn = json.choices.delta.tool_calls.function;
60-
s = struct('id',id,'type',type,'function',fcn);
61-
txt = jsonencode(s);
58+
if isfield(json,'choices')
59+
if isempty(json.choices)
60+
continue;
61+
end
62+
if isfield(json.choices,'finish_reason') && ...
63+
ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"])
64+
stop = true;
65+
return
66+
else
67+
if isfield(json.choices,"delta") && ...
68+
isfield(json.choices.delta,"tool_calls")
69+
if isfield(json.choices.delta.tool_calls,"id")
70+
id = json.choices.delta.tool_calls.id;
71+
type = json.choices.delta.tool_calls.type;
72+
fcn = json.choices.delta.tool_calls.function;
73+
s = struct('id',id,'type',type,'function',fcn);
74+
txt = jsonencode(s);
75+
else
76+
s = jsondecode(this.ResponseText);
77+
args = json.choices.delta.tool_calls.function.arguments;
78+
s.function.arguments = [s.function.arguments args];
79+
txt = jsonencode(s);
80+
end
81+
this.StreamFun('');
82+
this.ResponseText = txt;
6283
else
63-
s = jsondecode(this.ResponseText);
64-
args = json.choices.delta.tool_calls.function.arguments;
65-
s.function.arguments = [s.function.arguments args];
66-
txt = jsonencode(s);
84+
txt = json.choices.delta.content;
85+
this.StreamFun(txt);
86+
this.ResponseText = [this.ResponseText txt];
6787
end
68-
this.StreamFun('');
69-
this.ResponseText = txt;
70-
else
71-
txt = json.choices.delta.content;
72-
this.StreamFun(txt);
73-
this.ResponseText = [this.ResponseText txt];
7488
end
89+
else
90+
txt = json.message.content;
91+
this.StreamFun(txt);
92+
this.ResponseText = [this.ResponseText txt];
7593
end
7694
end
7795
end

tests/tazureChat.m

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ function doGenerate(testCase)
4848
testCase.verifyGreaterThan(strlength(response),0);
4949
end
5050

51+
function createOpenAIChatWithStreamFunc(testCase)
52+
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
53+
function seen = sf(str)
54+
persistent data;
55+
if isempty(data)
56+
data = strings(1, 0);
57+
end
58+
% Append streamed text to an empty string array of length 1
59+
data = [data, str];
60+
seen = data;
61+
end
62+
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ...
63+
StreamFun=@sf);
64+
65+
testCase.verifyWarningFree(@()generate(chat, "Hello world."));
66+
% Checking that persistent data, which is still stored in
67+
% memory, is greater than 1. This would mean that the stream
68+
% function has been called and streamed some text.
69+
testCase.verifyGreaterThan(numel(sf("")), 1);
70+
end
71+
5172
%% Test is currently unreliable, reasons unclear
5273
% function verySmallTimeOutErrors(testCase)
5374
% chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), TimeOut=1e-10, ApiKey="false-key");

tests/tollamaChat.m

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,24 @@ function stopSequences(testCase)
6868
testCase.verifyEqual(response2, extractBefore(response1,"1"));
6969
end
7070

71-
%% Test is currently unreliable, reasons unclear
72-
% function verySmallTimeOutErrors(testCase)
73-
% chat = ollamaChat("mistral", TimeOut=1e-10);
74-
% testCase.verifyError(@() generate(chat, "please count from 1 to 5000"), "MATLAB:webservices:Timeout")
75-
% end
71+
function streamFunc(testCase)
72+
function seen = sf(str)
73+
persistent data;
74+
if isempty(data)
75+
data = strings(1, 0);
76+
end
77+
% Append streamed text to an empty string array of length 1
78+
data = [data, str];
79+
seen = data;
80+
end
81+
chat = ollamaChat("mistral", StreamFun=@sf);
82+
83+
testCase.verifyWarningFree(@()generate(chat, "Hello world."));
84+
% Checking that persistent data, which is still stored in
85+
% memory, is greater than 1. This would mean that the stream
86+
% function has been called and streamed some text.
87+
testCase.verifyGreaterThan(numel(sf("")), 1);
88+
end
7689

7790
function invalidInputsConstructor(testCase, InvalidConstructorInput)
7891
testCase.verifyError(@() ollamaChat("mistral", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);

0 commit comments

Comments
 (0)