-
Notifications
You must be signed in to change notification settings - Fork 38
Adding support to Azure API #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a32e681
ccd6961
26b1272
8bd236b
e009e86
a6b8d51
4304636
61a5152
9914a55
cd9bbe2
8a2ea28
dde7d95
8d351a2
e229935
b67f5ef
bcf4d85
e51f5eb
08a8549
410a87b
6b830a2
c349a58
0815bdf
e8a900d
a2d65cd
5851aca
88f054b
5ed246d
3bd2208
beb41a5
8d50886
1a5ebb9
72dca78
3639020
a2b893a
0dd0e91
5d2fd99
4ae0248
735416b
95f2f13
492fefc
3000dc4
0973f11
238e11a
101cb59
e095802
f680fb5
7e81d37
2611327
16a3833
b33dad5
cdf0971
3abfae4
d577d36
466460d
24059e9
632ac22
d5eb25d
6108d0f
5200776
32547fb
4bd315c
fa9f06e
8336cbb
0543e3e
3ebc529
71f88a9
949fe42
09b8662
a360f89
348fb67
4fb866c
e29e65c
b4eb44a
757f260
56e6aaf
f6a9106
392749f
1ac24ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function versions = apiVersions | ||
%APIVERSIONS - supported Azure API versions | ||
|
||
% Copyright 2024 The MathWorks, Inc. | ||
versions = [... | ||
"2024-05-01-preview", ... | ||
"2024-04-01-preview", ... | ||
"2024-03-01-preview", ... | ||
"2024-02-01", ... | ||
"2023-05-15", ... | ||
]; | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp) | ||
% This function is undocumented and will change in a future release | ||
|
||
%callAzureChatAPI Calls the OpenAI chat completions API on Azure. | ||
% | ||
% MESSAGES and FUNCTIONS should be structs matching the json format | ||
% required by the OpenAI Chat Completions API. | ||
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api | ||
% | ||
% More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt | ||
% | ||
% Example | ||
% | ||
% % Create messages struct | ||
% messages = {struct("role", "system",... | ||
% "content", "You are a helpful assistant"); | ||
% struct("role", "user", ... | ||
% "content", "What is the edit distance between hi and hello?")}; | ||
% | ||
% % Create functions struct | ||
% functions = {struct("name", "editDistance", ... | ||
% "description", "Find edit distance between two strings or documents.", ... | ||
% "parameters", struct( ... | ||
% "type", "object", ... | ||
% "properties", struct(... | ||
% "str1", struct(... | ||
% "description", "Source string.", ... | ||
% "type", "string"),... | ||
% "str2", struct(... | ||
% "description", "Target string.", ... | ||
% "type", "string")),... | ||
% "required", ["str1", "str2"]))}; | ||
% | ||
% % Define your API key | ||
% apiKey = "your-api-key-here" | ||
% | ||
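% % Define your Azure endpoint (it is concatenated directly with | ||
% % "openai/deployments/", so it must end in "/") and deployment ID | ||
% endpoint = "https://your-resource-name.openai.azure.com/" | ||
% deploymentID = "your-deployment-id" | ||
% | ||
% % Send a request | ||
% [text, message] = llms.internal.callAzureChatAPI(endpoint, deploymentID, messages, functions, APIKey=apiKey) | ||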
|
||
% Copyright 2023-2024 The MathWorks, Inc. | ||
|
||
arguments | ||
endpoint | ||
deploymentID | ||
messages | ||
functions | ||
nvp.ToolChoice | ||
nvp.APIVersion | ||
nvp.Temperature | ||
nvp.TopP | ||
nvp.NumCompletions | ||
nvp.StopSequences | ||
nvp.MaxNumTokens | ||
nvp.PresencePenalty | ||
nvp.FrequencyPenalty | ||
nvp.ResponseFormat | ||
nvp.Seed | ||
nvp.APIKey | ||
nvp.TimeOut | ||
nvp.StreamFun | ||
end | ||
|
||
URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion; | ||
|
||
parameters = buildParametersCall(messages, functions, nvp); | ||
|
||
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun); | ||
|
||
% If call errors, "choices" will not be part of response.Body.Data, instead | ||
% we get response.Body.Data.error | ||
if response.StatusCode=="OK" | ||
% Outputs the first generation | ||
if isempty(nvp.StreamFun) | ||
message = response.Body.Data.choices(1).message; | ||
else | ||
message = struct("role", "assistant", ... | ||
"content", streamedText); | ||
end | ||
if isfield(message, "tool_choice") | ||
text = ""; | ||
else | ||
text = string(message.content); | ||
end | ||
else | ||
text = ""; | ||
message = struct(); | ||
end | ||
end | ||
|
||
function parameters = buildParametersCall(messages, functions, nvp) | ||
% Builds a struct in the format that is expected by the API, combining | ||
% MESSAGES, FUNCTIONS and parameters in NVP. | ||
|
||
parameters = struct(); | ||
parameters.messages = messages; | ||
|
||
parameters.stream = ~isempty(nvp.StreamFun); | ||
|
||
if ~isempty(functions) | ||
parameters.tools = functions; | ||
end | ||
|
||
if ~isempty(nvp.ToolChoice) | ||
parameters.tool_choice = nvp.ToolChoice; | ||
end | ||
|
||
if ~isempty(nvp.Seed) | ||
parameters.seed = nvp.Seed; | ||
end | ||
|
||
dict = mapNVPToParameters; | ||
|
||
nvpOptions = keys(dict); | ||
for opt = nvpOptions.' | ||
if isfield(nvp, opt) | ||
parameters.(dict(opt)) = nvp.(opt); | ||
end | ||
end | ||
end | ||
|
||
function dict = mapNVPToParameters() | ||
dict = dictionary(); | ||
dict("Temperature") = "temperature"; | ||
dict("TopP") = "top_p"; | ||
dict("NumCompletions") = "n"; | ||
dict("StopSequences") = "stop"; | ||
dict("MaxNumTokens") = "max_tokens"; | ||
dict("PresencePenalty") = "presence_penalty"; | ||
dict("FrequencyPenalty") = "frequency_penalty"; | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
function [text, message, response] = callOllamaChatAPI(model, messages, nvp) | ||
% This function is undocumented and will change in a future release | ||
|
||
%callOllamaChatAPI Calls the Ollama® chat completions API. | ||
% | ||
% MESSAGES should be structs matching the JSON format | ||
% required by the Ollama Chat Completions API. | ||
% Ref: https://github.com/ollama/ollama/blob/main/docs/api.md | ||
% | ||
% More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values | ||
% | ||
% Example | ||
% | ||
% model = "mistral"; | ||
% | ||
% % Create messages struct | ||
% messages = {struct("role", "system",... | ||
% "content", "You are a helpful assistant"); | ||
% struct("role", "user", ... | ||
% "content", "What is the edit distance between hi and hello?")}; | ||
% | ||
% % Send a request | ||
% [text, message] = llms.internal.callOllamaChatAPI(model, messages) | ||
|
||
% Copyright 2023-2024 The MathWorks, Inc. | ||
|
||
arguments | ||
model | ||
messages | ||
nvp.Temperature | ||
nvp.TopP | ||
nvp.TopK | ||
nvp.TailFreeSamplingZ | ||
nvp.StopSequences | ||
nvp.MaxNumTokens | ||
nvp.ResponseFormat | ||
nvp.Seed | ||
nvp.TimeOut | ||
nvp.StreamFun | ||
end | ||
|
||
URL = "http://localhost:11434/api/chat"; | ||
|
||
% The JSON for StopSequences must have an array, and cannot say "stop": "foo". | ||
% The easiest way to ensure that is to never pass in a scalar … | ||
Review comment: How about: "To ensure that the JSON for StopSequences has a non-scalar array (is a non-scalar array?), create a two-element array if a scalar is passed." Reply: No, in JSON, a one-element array is still an array and would be fine. MATLAB's … |
||
if isscalar(nvp.StopSequences) | ||
nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences]; | ||
end | ||
|
||
parameters = buildParametersCall(model, messages, nvp); | ||
|
||
[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun); | ||
|
||
% If call errors, "message" will not be part of response.Body.Data, instead | ||
% we get response.Body.Data.error | ||
if response.StatusCode=="OK" | ||
% Outputs the first generation | ||
if isempty(nvp.StreamFun) | ||
message = response.Body.Data.message; | ||
else | ||
message = struct("role", "assistant", ... | ||
"content", streamedText); | ||
end | ||
text = string(message.content); | ||
else | ||
text = ""; | ||
message = struct(); | ||
end | ||
end | ||
|
||
function parameters = buildParametersCall(model, messages, nvp) | ||
% Builds a struct in the format that is expected by the API, combining | ||
% MODEL, MESSAGES and parameters in NVP. | ||
|
||
parameters = struct(); | ||
parameters.model = model; | ||
parameters.messages = messages; | ||
|
||
parameters.stream = ~isempty(nvp.StreamFun); | ||
|
||
options = struct; | ||
if ~isempty(nvp.Seed) | ||
options.seed = nvp.Seed; | ||
end | ||
|
||
dict = mapNVPToParameters; | ||
|
||
nvpOptions = keys(dict); | ||
for opt = nvpOptions.' | ||
if isfield(nvp, opt) && ~isempty(nvp.(opt)) && ~isequaln(nvp.(opt),Inf) | ||
options.(dict(opt)) = nvp.(opt); | ||
end | ||
end | ||
|
||
parameters.options = options; | ||
end | ||
|
||
function dict = mapNVPToParameters() | ||
dict = dictionary(); | ||
dict("Temperature") = "temperature"; | ||
dict("TopP") = "top_p"; | ||
dict("TopK") = "top_k"; | ||
dict("TailFreeSamplingZ") = "tfs_z"; | ||
dict("StopSequences") = "stop"; | ||
dict("MaxNumTokens") = "num_predict"; | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,23 @@ | ||
function key = getApiKeyFromNvpOrEnv(nvp) | ||
function key = getApiKeyFromNvpOrEnv(nvp,envVarName) | ||
ccreutzi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
% This function is undocumented and will change in a future release | ||
|
||
%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable. | ||
% | ||
% This function takes a struct nvp containing name-value pairs and checks | ||
% if it contains a field called "ApiKey". If the field is not found, | ||
% the function attempts to retrieve the API key from an environment | ||
% variable called "OPENAI_API_KEY". If both methods fail, the function | ||
% throws an error. | ||
% This function takes a struct nvp containing name-value pairs and checks if | ||
% it contains a field called "APIKey". If the field is not found, the | ||
% function attempts to retrieve the API key from an environment variable | ||
% whose name is given as the second argument. If both methods fail, the | ||
% function throws an error. | ||
|
||
% Copyright 2023 The MathWorks, Inc. | ||
% Copyright 2023-2024 The MathWorks, Inc. | ||
|
||
if isfield(nvp, "ApiKey") | ||
key = nvp.ApiKey; | ||
if isfield(nvp, "APIKey") | ||
key = nvp.APIKey; | ||
else | ||
if isenv("OPENAI_API_KEY") | ||
key = getenv("OPENAI_API_KEY"); | ||
if isenv(envVarName) | ||
key = getenv(envVarName); | ||
else | ||
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified")); | ||
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName)); | ||
end | ||
end | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
classdef (Abstract) gptPenalties | ||
% This class is undocumented and will change in a future release | ||
|
||
% Copyright 2024 The MathWorks, Inc. | ||
properties | ||
%PRESENCEPENALTY Penalty for using a token in the response that has already been used. | ||
PresencePenalty {llms.utils.mustBeValidPenalty} = 0 | ||
|
||
%FREQUENCYPENALTY Penalty for using a token that has already appeared frequently in the generated text. | ||
FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.