Adding support to Azure API #8


Merged
78 commits merged into main from AzureAPI on Jun 24, 2024
Commits (78)
a32e681
Adding support to Azure API
Feb 9, 2024
ccd6961
merge main
ccreutzi May 23, 2024
26b1272
parameterize getApiKeyFromNvpOrEnv, allowing different env variables …
ccreutzi May 27, 2024
8bd236b
get basic Azure connection working
ccreutzi May 27, 2024
e009e86
even smaller timeout; failed to throw an error in GitHub once
ccreutzi May 27, 2024
a6b8d51
add ollamaChat class
ccreutzi May 27, 2024
4304636
CI setup for ollama
ccreutzi May 27, 2024
61a5152
typos
ccreutzi May 27, 2024
9914a55
disable verySmallTimeOutErrors test points, since they are flaky
ccreutzi May 27, 2024
cd9bbe2
Updated README.md for Azure and Ollama
ccreutzi May 28, 2024
8a2ea28
Remove GPT specific penalties from ollamaChat
ccreutzi May 29, 2024
dde7d95
Implement TopProbabilityNum and StopSequences for ollamaChat
ccreutzi Jun 4, 2024
8d351a2
increase default timeout to 120 seconds for ollamaChat
ccreutzi Jun 4, 2024
e229935
add TailFreeSampling_Z, add comment about currently unsupported ollam…
ccreutzi Jun 5, 2024
b67f5ef
new static method ollamaChat.models
ccreutzi Jun 5, 2024
bcf4d85
update API versions, following https://learn.microsoft.com/en-us/azur…
ccreutzi Jun 5, 2024
e51f5eb
typo in help header
ccreutzi Jun 5, 2024
08a8549
add azureChat and ollamaChat to functionSignatures.json
ccreutzi Jun 5, 2024
410a87b
Make StreamFun work with ollamaChat and azureChat
ccreutzi Jun 5, 2024
6b830a2
remove unused defaults, for more realistic coverage numbers
ccreutzi Jun 5, 2024
c349a58
Add test that azureChat with Seed fixes result
ccreutzi Jun 5, 2024
0815bdf
try telling codecov to not worry about test files
ccreutzi Jun 6, 2024
e8a900d
also ignore errorMessageCatalog.m in codecov, since almost all of it …
ccreutzi Jun 6, 2024
a2d65cd
ignore examples/data/* just like data/*
ccreutzi Jun 6, 2024
5851aca
Merge branch 'main' into AzureAPI
ccreutzi Jun 6, 2024
88f054b
remove disabled timeout tests
ccreutzi Jun 6, 2024
5ed246d
Add explanatory comment for missing key test.
ccreutzi Jun 6, 2024
3bd2208
changed wording as requested
ccreutzi Jun 6, 2024
beb41a5
simplify comment
ccreutzi Jun 6, 2024
8d50886
spell Ollama with a capital O
ccreutzi Jun 6, 2024
1a5ebb9
Fix capitalization: APIKey, by MathWorks naming standards.
ccreutzi Jun 6, 2024
72dca78
Codecov has problems with uploaded coverage. Try not using plus signs.
ccreutzi Jun 7, 2024
3639020
add unit tests for edge cases and errors in responseStreamer
ccreutzi Jun 7, 2024
a2b893a
Add test point for function calls
ccreutzi Jun 7, 2024
0dd0e91
CI setup stores the API key in a different variable
ccreutzi Jun 7, 2024
5d2fd99
for better coverage, run tools test through streaming API
ccreutzi Jun 7, 2024
4ae0248
Function calling on Azure.
ccreutzi Jun 7, 2024
735416b
Throw server errors as errors
ccreutzi Jun 10, 2024
95f2f13
For CI, add `$OPENAI_API_KEY` such that `openAIChat` works
ccreutzi Jun 10, 2024
492fefc
Short error messages for bad endpoints
ccreutzi Jun 11, 2024
3000dc4
Nicer help headers
ccreutzi Jun 11, 2024
0973f11
typos
ccreutzi Jun 12, 2024
238e11a
Rename openAIMessages to messageHistory
ccreutzi Jun 12, 2024
101cb59
Add openAIMessages fallback for backward compatibility
ccreutzi Jun 12, 2024
e095802
Modify chatbot example to use Ollama
ccreutzi Jun 12, 2024
f680fb5
Merge branch 'main' into AzureAPI
ccreutzi Jun 12, 2024
7e81d37
minimal and complete test for the backward compatibility function
ccreutzi Jun 12, 2024
2611327
Avoid bogus json
ccreutzi Jun 12, 2024
16a3833
Improve ollamaChat tab completion
ccreutzi Jun 12, 2024
b33dad5
Remove unused error ID
ccreutzi Jun 12, 2024
cdf0971
add test for Ollama chatbot example
ccreutzi Jun 13, 2024
3abfae4
mark trademarks
ccreutzi Jun 13, 2024
d577d36
Include ._* in .gitignore
ccreutzi Jun 13, 2024
466460d
Rename `TopProbabilityMass` → `TopP`, `TopProbabilityNum` → `TopK`
ccreutzi Jun 17, 2024
24059e9
Take properties out of `ollamaChat` that do not apply: Tools and API key
ccreutzi Jun 17, 2024
632ac22
Merge branch 'main' into AzureAPI
ccreutzi Jun 17, 2024
d5eb25d
Only drop `:latest` from model list
ccreutzi Jun 17, 2024
6108d0f
ollamaChat.models should never return <missing>
ccreutzi Jun 19, 2024
5200776
fix indentations changed by renaming `TopProbabilityMass` to `TopP`
ccreutzi Jun 19, 2024
32547fb
accept char and cellstr input for generate
ccreutzi Jun 19, 2024
4bd315c
split README.md by backend
ccreutzi Jun 19, 2024
fa9f06e
`FunctionNames` should only exist for connectors with tools
ccreutzi Jun 20, 2024
8336cbb
Fix link typo
ccreutzi Jun 20, 2024
0543e3e
Fix typo: This is not using OpenAI
ccreutzi Jun 20, 2024
3ebc529
update tests to expect correct errors
ccreutzi Jun 20, 2024
71f88a9
tabs to spaces
ccreutzi Jun 20, 2024
949fe42
`openAIImages` should derive from `needsAPIKey`
ccreutzi Jun 20, 2024
09b8662
test `NumCompletions`
ccreutzi Jun 20, 2024
a360f89
Ollama does not support `NumCompletions`
ccreutzi Jun 20, 2024
348fb67
`azureChat` should get endpoint and deployment from env
ccreutzi Jun 20, 2024
4fb866c
clean up comments
ccreutzi Jun 21, 2024
e29e65c
move error text to catalogue
ccreutzi Jun 21, 2024
b4eb44a
Reorder test points for maintainability
ccreutzi Jun 21, 2024
757f260
Add a streaming example for `ollamaChat`
ccreutzi Jun 21, 2024
56e6aaf
Log Ollama version during CI
ccreutzi Jun 24, 2024
f6a9106
Add Seed test to tollamaChat.m
ccreutzi Jun 24, 2024
392749f
Use message from error catalog
ccreutzi Jun 24, 2024
1ac24ff
Disable flaky test points
ccreutzi Jun 24, 2024
12 changes: 12 additions & 0 deletions +llms/+azure/apiVersions.m
@@ -0,0 +1,12 @@
function versions = apiVersions
%APIVERSIONS - Supported Azure API versions

% Copyright 2024 The MathWorks, Inc.
versions = [...
"2024-05-01-preview", ...
"2024-04-01-preview", ...
"2024-03-01-preview", ...
"2024-02-01", ...
"2023-05-15", ...
];
end
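A minimal usage sketch (the output follows directly from the hard-coded list above):

versions = llms.azure.apiVersions;
latest = versions(1)   % "2024-05-01-preview"; the list is ordered newest first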
130 changes: 130 additions & 0 deletions +llms/+internal/callAzureChatAPI.m
@@ -0,0 +1,130 @@
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
% This function is undocumented and will change in a future release

%callAzureChatAPI Calls the OpenAI chat completions API on Azure.
%
% MESSAGES and FUNCTIONS should be structs matching the JSON format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt
%
% Example
%
% % Create messages struct
% messages = {struct("role", "system",...
% "content", "You are a helpful assistant");
% struct("role", "user", ...
% "content", "What is the edit distance between hi and hello?")};
%
% % Create functions struct
% functions = {struct("name", "editDistance", ...
% "description", "Find edit distance between two strings or documents.", ...
% "parameters", struct( ...
% "type", "object", ...
% "properties", struct(...
% "str1", struct(...
% "description", "Source string.", ...
% "type", "string"),...
% "str2", struct(...
% "description", "Target string.", ...
% "type", "string")),...
% "required", ["str1", "str2"]))};
%
% % Define the endpoint, deployment ID, and API key (placeholder values)
% endpoint = "https://your-resource.openai.azure.com/";
% deploymentID = "your-deployment";
% apiKey = "your-api-key-here";
%
% % Send a request
% [text, message] = llms.internal.callAzureChatAPI(endpoint, deploymentID, messages, functions, APIKey=apiKey)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
endpoint
deploymentID
messages
functions
nvp.ToolChoice
nvp.APIVersion
nvp.Temperature
nvp.TopP
nvp.NumCompletions
nvp.StopSequences
nvp.MaxNumTokens
nvp.PresencePenalty
nvp.FrequencyPenalty
nvp.ResponseFormat
nvp.Seed
nvp.APIKey
nvp.TimeOut
nvp.StreamFun
end

URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;
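% For illustration, with hypothetical values endpoint = "https://my-resource.openai.azure.com/",
% deploymentID = "my-gpt-deployment", and APIVersion = "2024-02-01", URL becomes:
% https://my-resource.openai.azure.com/openai/deployments/my-gpt-deployment/chat/completions?api-version=2024-02-01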

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);

% If the call errors, "choices" will not be part of response.Body.Data;
% instead we get response.Body.Data.error.
if response.StatusCode=="OK"
% Outputs the first generation
if isempty(nvp.StreamFun)
message = response.Body.Data.choices(1).message;
else
message = struct("role", "assistant", ...
"content", streamedText);
end
if isfield(message, "tool_choice")
text = "";
else
text = string(message.content);
end
else
text = "";
message = struct();
end
end

function parameters = buildParametersCall(messages, functions, nvp)
% Builds a struct in the format that is expected by the API, combining
% MESSAGES, FUNCTIONS and parameters in NVP.

parameters = struct();
parameters.messages = messages;

parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice)
parameters.tool_choice = nvp.ToolChoice;
end

if ~isempty(nvp.Seed)
parameters.seed = nvp.Seed;
end

dict = mapNVPToParameters;

nvpOptions = keys(dict);
for opt = nvpOptions.'
if isfield(nvp, opt)
parameters.(dict(opt)) = nvp.(opt);
end
end
end

function dict = mapNVPToParameters()
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopP") = "top_p";
dict("NumCompletions") = "n";
dict("StopSequences") = "stop";
dict("MaxNumTokens") = "max_tokens";
dict("PresencePenalty") = "presence_penalty";
dict("FrequencyPenalty") = "frequency_penalty";
end
106 changes: 106 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
@@ -0,0 +1,106 @@
function [text, message, response] = callOllamaChatAPI(model, messages, nvp)
% This function is undocumented and will change in a future release

%callOllamaChatAPI Calls the Ollama® chat completions API.
%
% MESSAGES should be a struct array matching the JSON format
% required by the Ollama chat API.
% Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
%
% More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
%
% Example
%
% model = "mistral";
%
% % Create messages struct
% messages = {struct("role", "system",...
% "content", "You are a helpful assistant");
% struct("role", "user", ...
% "content", "What is the edit distance between hi and hello?")};
%
% % Send a request
% [text, message] = llms.internal.callOllamaChatAPI(model, messages)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
model
messages
nvp.Temperature
nvp.TopP
nvp.TopK
nvp.TailFreeSamplingZ
nvp.StopSequences
nvp.MaxNumTokens
nvp.ResponseFormat
nvp.Seed
nvp.TimeOut
nvp.StreamFun
end

URL = "http://localhost:11434/api/chat";

% The JSON for StopSequences must have an array, and cannot say "stop": "foo".
% The easiest way to ensure that is to never pass in a scalar, but to
% duplicate a scalar into a two-element array instead.

[Review discussion on the comment above]

Collaborator: How about: "To ensure that the JSON for StopSequences has a non-scalar array (is a non-scalar array?), create a two-element array if a scalar is passed."

Member: No, in JSON, a one-element array is still an array and would be fine. MATLAB's jsonencode just doesn't create "stop": ["foo"] as we need it to. But we can make it produce "stop": ["foo","foo"] instead.
if isscalar(nvp.StopSequences)
nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences];
end
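% A sketch of the jsonencode behavior discussed in the review thread above:
%   jsonencode(struct("stop","foo"))           % returns '{"stop":"foo"}'
%   jsonencode(struct("stop",["foo","foo"]))   % returns '{"stop":["foo","foo"]}'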

parameters = buildParametersCall(model, messages, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);

% If the call errors, the response body contains an error field
% instead of the expected message.
if response.StatusCode=="OK"
% Outputs the first generation
if isempty(nvp.StreamFun)
message = response.Body.Data.message;
else
message = struct("role", "assistant", ...
"content", streamedText);
end
text = string(message.content);
else
text = "";
message = struct();
end
end

function parameters = buildParametersCall(model, messages, nvp)
% Builds a struct in the format that is expected by the API, combining
% MODEL, MESSAGES, and the parameters in NVP.

parameters = struct();
parameters.model = model;
parameters.messages = messages;

parameters.stream = ~isempty(nvp.StreamFun);

options = struct;
if ~isempty(nvp.Seed)
options.seed = nvp.Seed;
end

dict = mapNVPToParameters;

nvpOptions = keys(dict);
for opt = nvpOptions.'
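% Skip options the caller did not set; a value of Inf (e.g., an unset
% MaxNumTokens) is treated as unset and never sent to the server.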
if isfield(nvp, opt) && ~isempty(nvp.(opt)) && ~isequaln(nvp.(opt),Inf)
options.(dict(opt)) = nvp.(opt);
end
end

parameters.options = options;
end

function dict = mapNVPToParameters()
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopP") = "top_p";
dict("TopK") = "top_k";
dict("TailFreeSamplingZ") = "tfs_z";
dict("StopSequences") = "stop";
dict("MaxNumTokens") = "num_predict";
end
51 changes: 19 additions & 32 deletions +llms/+internal/callOpenAIChatAPI.m
@@ -1,25 +1,12 @@
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
% This function is undocumented and will change in a future release

%callOpenAIChatAPI Calls the OpenAI chat completions API.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
- % Currently, the supported NVP are, including the equivalent name in the API:
- % - ToolChoice (tool_choice)
- % - ModelName (model)
- % - Temperature (temperature)
- % - TopProbabilityMass (top_p)
- % - NumCompletions (n)
- % - StopSequences (stop)
- % - MaxNumTokens (max_tokens)
- % - PresencePenalty (presence_penalty)
- % - FrequencyPenalty (frequence_penalty)
- % - ResponseFormat (response_format)
- % - Seed (seed)
- % - ApiKey
- % - TimeOut
- % - StreamFun
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
% Example
@@ -48,34 +35,34 @@
% apiKey = "your-api-key-here"
%
% % Send a request
- % [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
+ % [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, APIKey=apiKey)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
messages
functions
- nvp.ToolChoice = []
- nvp.ModelName = "gpt-3.5-turbo"
- nvp.Temperature = 1
- nvp.TopProbabilityMass = 1
- nvp.NumCompletions = 1
- nvp.StopSequences = []
- nvp.MaxNumTokens = inf
- nvp.PresencePenalty = 0
- nvp.FrequencyPenalty = 0
- nvp.ResponseFormat = "text"
- nvp.Seed = []
- nvp.ApiKey = ""
- nvp.TimeOut = 10
- nvp.StreamFun = []
+ nvp.ToolChoice
+ nvp.ModelName
+ nvp.Temperature
+ nvp.TopP
+ nvp.NumCompletions
+ nvp.StopSequences
+ nvp.MaxNumTokens
+ nvp.PresencePenalty
+ nvp.FrequencyPenalty
+ nvp.ResponseFormat
+ nvp.Seed
+ nvp.APIKey
+ nvp.TimeOut
+ nvp.StreamFun
end

END_POINT = "https://api.openai.com/v1/chat/completions";

parameters = buildParametersCall(messages, functions, nvp);

- [response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
+ [response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, END_POINT, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
@@ -160,7 +147,7 @@
function dict = mapNVPToParameters()
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopProbabilityMass") = "top_p";
dict("TopP") = "top_p";
dict("NumCompletions") = "n";
dict("StopSequences") = "stop";
dict("MaxNumTokens") = "max_tokens";
26 changes: 13 additions & 13 deletions +llms/+internal/getApiKeyFromNvpOrEnv.m
@@ -1,23 +1,23 @@
- function key = getApiKeyFromNvpOrEnv(nvp)
+ function key = getApiKeyFromNvpOrEnv(nvp,envVarName)
% This function is undocumented and will change in a future release

%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable.
%
- % This function takes a struct nvp containing name-value pairs and checks
- % if it contains a field called "ApiKey". If the field is not found,
- % the function attempts to retrieve the API key from an environment
- % variable called "OPENAI_API_KEY". If both methods fail, the function
- % throws an error.
+ % This function takes a struct nvp containing name-value pairs and checks if
+ % it contains a field called "APIKey". If the field is not found, the
+ % function attempts to retrieve the API key from an environment variable
+ % whose name is given as the second argument. If both methods fail, the
+ % function throws an error.

- % Copyright 2023 The MathWorks, Inc.
+ % Copyright 2023-2024 The MathWorks, Inc.

if isfield(nvp, "ApiKey")
key = nvp.ApiKey;
if isfield(nvp, "APIKey")
key = nvp.APIKey;
else
if isenv("OPENAI_API_KEY")
key = getenv("OPENAI_API_KEY");
if isenv(envVarName)
key = getenv(envVarName);
else
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified"));
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName));
end
end
- end
+ end
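A usage sketch of the new envVarName argument (the Azure variable name below is an assumption based on the CI commits, not confirmed in this diff):

% hypothetical Azure connector call (assumed env variable name):
key = llms.internal.getApiKeyFromNvpOrEnv(nvp, "AZURE_OPENAI_API_KEY");
% openAIChat keeps the historical behavior:
key = llms.internal.getApiKeyFromNvpOrEnv(nvp, "OPENAI_API_KEY");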
12 changes: 12 additions & 0 deletions +llms/+internal/gptPenalties.m
@@ -0,0 +1,12 @@
classdef (Abstract) gptPenalties
% This class is undocumented and will change in a future release

% Copyright 2024 The MathWorks, Inc.
properties
%PRESENCEPENALTY Penalty for using a token in the response that has already been used.
PresencePenalty {llms.utils.mustBeValidPenalty} = 0

%FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
end
end
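For context, a minimal sketch of how a connector picks up these properties (hypothetical class name; per the commit history, ollamaChat deliberately does not inherit them):

classdef demoChat < llms.internal.gptPenalties
    % inherits PresencePenalty and FrequencyPenalty, each validated by
    % llms.utils.mustBeValidPenalty and defaulting to 0
end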