Skip to content

Commit 348fb67

Browse files
committed
azureChat should get endpoint and deployment from env
Removed the positional arguments from `azureChat` constructor, and replaced with NVPs that default to reading from the environment instead. Updated instructions to include those environment variables in `.env`.
1 parent a360f89 commit 348fb67

File tree

5 files changed

+74
-43
lines changed

5 files changed

+74
-43
lines changed

+llms/+utils/errorMessageCatalog.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
4444
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
4545
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
46+
catalog("llms:endpointMustBeSpecified") = "Unable to find endpoint. Either set environment variable AZURE_OPENAI_ENDPOINT or specify name-value argument ""Endpoint"".";
47+
catalog("llms:deploymentMustBeSpecified") = "Unable to find deployment name. Either set environment variable AZURE_OPENAI_DEPLOYMENT or specify name-value argument ""Deployment"".";
4648
catalog("llms:keyMustBeSpecified") = "Unable to find API key. Either set environment variable {1} or specify name-value argument ""APIKey"".";
4749
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
4850
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";

azureChat.m

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,27 @@
22
llms.internal.gptPenalties & llms.internal.hasTools & llms.internal.needsAPIKey
33
%azureChat Chat completion API from Azure.
44
%
5-
% CHAT = azureChat(endpoint, deploymentID) creates an azureChat object with
6-
% the endpoint and deployment ID path parameters required by Azure to
7-
% establish the connection.
5+
% CHAT = azureChat creates an azureChat object, with the parameters needed
6+
% to connect to Azure taken from the environment.
87
%
9-
% CHAT = azureChat(__,systemPrompt) creates an azureChat object with the
8+
% CHAT = azureChat(systemPrompt) creates an azureChat object with the
109
% specified system prompt.
1110
%
1211
% CHAT = azureChat(__,Name=Value) specifies additional options
1312
% using one or more name-value arguments:
1413
%
14+
% Endpoint - The endpoint as defined in the Azure OpenAI Services
15+
% interface. Needs to be specified or stored in the
16+
% environment variable AZURE_OPENAI_ENDPOINT.
17+
%
18+
% Deployment - The deployment as defined in the Azure OpenAI Services
19+
% interface. Needs to be specified or stored in the
20+
% environment variable AZURE_OPENAI_DEPLOYMENT.
21+
%
22+
% APIKey - The API key for accessing the Azure OpenAI Chat API.
23+
% Needs to be specified or stored in the
24+
% environment variable AZURE_OPENAI_API_KEY.
25+
%
1526
% Temperature - Temperature value for controlling the randomness
1627
% of the output. Default value is 1; higher values
1728
% increase the randomness (in some sense,
@@ -33,8 +44,6 @@
3344
% ResponseFormat - The format of response the model returns.
3445
% "text" (default) | "json"
3546
%
36-
% APIKey - The API key for accessing the OpenAI Chat API.
37-
%
3847
% PresencePenalty - Penalty value for using a token in the response
3948
% that has already been used. Default value is 0.
4049
% Higher values reduce repetition of words in the output.
@@ -91,18 +100,18 @@
91100
end
92101

93102
methods
94-
function this = azureChat(endpoint, deploymentID, systemPrompt, nvp)
103+
function this = azureChat(systemPrompt, nvp)
95104
arguments
96-
endpoint {mustBeTextScalar}
97-
deploymentID {mustBeTextScalar}
98105
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
106+
nvp.Endpoint {mustBeNonzeroLengthTextScalar}
107+
nvp.Deployment {mustBeNonzeroLengthTextScalar}
108+
nvp.APIKey {mustBeNonzeroLengthTextScalar}
99109
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
100110
nvp.APIVersion (1,1) {mustBeAPIVersion} = "2024-02-01"
101111
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
102112
nvp.TopP {llms.utils.mustBeValidTopP} = 1
103113
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
104114
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
105-
nvp.APIKey {mustBeNonzeroLengthTextScalar}
106115
nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0
107116
nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
108117
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
@@ -131,16 +140,16 @@
131140
end
132141
end
133142

134-
this.Endpoint = endpoint;
135-
this.DeploymentID = deploymentID;
143+
this.Endpoint = getEndpoint(nvp);
144+
this.DeploymentID = getDeployment(nvp);
145+
this.APIKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"AZURE_OPENAI_API_KEY");
136146
this.APIVersion = nvp.APIVersion;
137147
this.ResponseFormat = nvp.ResponseFormat;
138148
this.Temperature = nvp.Temperature;
139149
this.TopP = nvp.TopP;
140150
this.StopSequences = nvp.StopSequences;
141151
this.PresencePenalty = nvp.PresencePenalty;
142152
this.FrequencyPenalty = nvp.FrequencyPenalty;
143-
this.APIKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"AZURE_OPENAI_API_KEY");
144153
this.TimeOut = nvp.TimeOut;
145154
end
146155

@@ -285,3 +294,27 @@ function mustBeIntegerOrEmpty(value)
285294
function mustBeAPIVersion(model)
286295
mustBeMember(model,llms.azure.apiVersions);
287296
end
297+
298+
function endpoint = getEndpoint(nvp)
299+
if isfield(nvp, "Endpoint")
300+
endpoint = nvp.Endpoint;
301+
else
302+
if isenv("AZURE_OPENAI_ENDPOINT")
303+
endpoint = getenv("AZURE_OPENAI_ENDPOINT");
304+
else
305+
error("llms:endpointMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:endpointMustBeSpecified"));
306+
end
307+
end
308+
end
309+
310+
function deployment = getDeployment(nvp)
311+
if isfield(nvp, "Deployment")
312+
deployment = nvp.Deployment;
313+
else
314+
if isenv("AZURE_OPENAI_DEPLOYMENT")
315+
deployment = getenv("AZURE_OPENAI_DEPLOYMENT");
316+
else
317+
error("llms:deploymentMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:deploymentMustBeSpecified"));
318+
end
319+
end
320+
end

doc/Azure.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Some of the [current LLMs supported on Azure](https://learn.microsoft.com/en-us/
1616
Set up your [endpoint and deployment and retrieve one of the API keys](https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line%2Cpython-new&pivots=rest-api#retrieve-key-and-endpoint). Create a `.env` file in the project root directory with the following content.
1717

1818
```
19+
AZURE_OPENAI_ENDPOINT=<your endpoint>
20+
AZURE_OPENAI_DEPLOYMENT=<your deployment>
1921
AZURE_OPENAI_API_KEY=<your key>
2022
```
2123

@@ -29,11 +31,11 @@ loadenv(".env")
2931

3032
## Establishing a connection to Chat Completions API using Azure
3133

32-
To connect MATLAB to Chat Completions API via Azure, you will have to create an `azureChat` object. See [the Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart) for details on the setup required and where to find your key, endpoint, and deployment name. As explained above, the key should be in the environment variable `AZURE_OPENAI_API_KEY`, or provided as `APIKey=…` in the `azureChat` call below.
34+
To connect MATLAB to Chat Completions API via Azure, you will have to create an `azureChat` object. See [the Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart) for details on the setup required and where to find your key, endpoint, and deployment name. As explained above, the endpoint, deployment, and key should be in the environment variables `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT`, and `AZURE_OPENAI_API_KEY`, or provided as `Endpoint=…`, `Deployment=…`, and `APIKey=…` in the `azureChat` call below.
3335

34-
In order to create the chat assistant, specify your Azure OpenAI Resource and the LLM you want to use:
36+
In order to create the chat assistant, use the `azureChat` function, optionally providing a system prompt:
3537
```matlab
36-
chat = azureChat(YOUR_ENDPOINT_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant");
38+
chat = azureChat("You are a helpful AI assistant");
3739
```
3840

3941
The `azureChat` object also allows you to specify additional options. Call `help azureChat` for more information.
@@ -60,7 +62,7 @@ systemPrompt = "You are a sentiment analyser. You will look at a sentence and ou
6062
"His attitude was terribly discouraging to the team." + newline +...
6163
"negative" + newline + newline;
6264
63-
chat = azureChat(YOUR_ENDPOINT_NAME, YOUR_DEPLOYMENT_NAME, systemPrompt);
65+
chat = azureChat(systemPrompt);
6466
6567
% Generate a response, passing a new sentence for classification
6668
txt = generate(chat,"The team is feeling very motivated")
@@ -80,7 +82,7 @@ history = messageHistory;
8082
Then create the chat assistant:
8183

8284
```matlab
83-
chat = azureChat(YOUR_ENDPOINT_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant.");
85+
chat = azureChat;
8486
```
8587

8688
Add a user message to the history and pass it to `generate`:
@@ -108,7 +110,7 @@ Streaming allows you to start receiving the output from the API as it is generat
108110
```matlab
109111
% streaming function
110112
sf = @(x) fprintf("%s",x);
111-
chat = azureChat(YOUR_ENDPOINT_NAME, YOUR_DEPLOYMENT_NAME, StreamFun=sf);
113+
chat = azureChat(StreamFun=sf);
112114
txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?")
113115
% Should stream the response token by token
114116
```
@@ -123,7 +125,7 @@ For example, if you want to use the API for mathematical operations such as `sin
123125
```matlab
124126
f = openAIFunction("sind","Sine of argument in degrees");
125127
f = addParameter(f,"x",type="number",description="Angle in degrees.");
126-
chat = azureChat(YOUR_ENDPOINT_NAME,YOUR_DEPLOYMENT_NAME,"You are a helpful assistant.",Tools=f);
128+
chat = azureChat("You are a helpful assistant.",Tools=f);
127129
```
128130

129131
When the model identifies that it could use the defined functions to answer a query, it will return a `tool_calls` request, instead of directly generating the response:
@@ -217,8 +219,7 @@ f = addParameter(f,"patientSymptoms",type="string",description="Symptoms that th
217219
Note that this function does not need to exist, since it will only be used to extract the Name, Age and Symptoms of the patient and it does not need to be called:
218220

219221
```matlab
220-
chat = azureChat(YOUR_ENDPOINT_NAME, YOUR_DEPLOYMENT_NAME, ...
221-
"You are helpful assistant that reads patient records and extracts information", ...
222+
chat = azureChat("You are helpful assistant that reads patient records and extracts information", ...
222223
Tools=f);
223224
messages = messageHistory;
224225
messages = addUserMessage(messages,"Extract the information from the report:" + newline + patientReport);
@@ -258,4 +259,3 @@ ans =
258259
```
259260

260261
You can extract the arguments and write the data to a table, for example.
261-

functionSignatures.json

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@
4444
{
4545
"inputs":
4646
[
47-
{"name":"endpoint","kind":"positional","type":["string","scalar"]},
48-
{"name":"deploymentID","kind":"positional","type":["string","scalar"]},
4947
{"name":"systemPrompt","kind":"ordered","type":["string","scalar"]},
48+
{"name":"Endpoint","kind":"namevalue","type":["string","scalar"]},
49+
{"name":"Deployment","kind":"namevalue","type":["string","scalar"]},
50+
{"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
5051
{"name":"Tools","kind":"namevalue","type":"openAIFunction"},
5152
{"name":"APIVersion","kind":"namevalue","type":"choices=llms.azure.apiVersions"},
5253
{"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
5354
{"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
5455
{"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
5556
{"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
56-
{"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
5757
{"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
5858
{"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
5959
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
@@ -108,7 +108,6 @@
108108
[
109109
{"name":"this","kind":"required","type":["ollamaChat","scalar"]},
110110
{"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]},
111-
{"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
112111
{"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
113112
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
114113
],

tests/tazureChat.m

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
methods(Test)
1414
function constructChatWithAllNVP(testCase)
15-
endpoint = getenv("AZURE_OPENAI_ENDPOINT");
1615
deploymentID = "hello";
1716
functions = openAIFunction("funName");
1817
temperature = 0;
@@ -23,7 +22,7 @@ function constructChatWithAllNVP(testCase)
2322
frequenceP = 2;
2423
systemPrompt = "This is a system prompt";
2524
timeout = 3;
26-
chat = azureChat(endpoint, deploymentID, systemPrompt, Tools=functions, ...
25+
chat = azureChat(systemPrompt, Deployment=deploymentID, Tools=functions, ...
2726
Temperature=temperature, TopP=topP, StopSequences=stop, APIKey=apiKey,...
2827
FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout);
2928
testCase.verifyEqual(chat.Temperature, temperature);
@@ -35,22 +34,22 @@ function constructChatWithAllNVP(testCase)
3534

3635
function doGenerate(testCase,StringInputs)
3736
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
38-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
37+
chat = azureChat;
3938
response = testCase.verifyWarningFree(@() generate(chat,StringInputs));
4039
testCase.verifyClass(response,'string');
4140
testCase.verifyGreaterThan(strlength(response),0);
4241
end
4342

4443
function generateMultipleResponses(testCase)
45-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
44+
chat = azureChat;
4645
[~,~,response] = generate(chat,"What is a cat?",NumCompletions=3);
4746
testCase.verifySize(response.Body.Data.choices,[3,1]);
4847
end
4948

5049

5150
function doReturnErrors(testCase)
5251
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
53-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
52+
chat = azureChat;
5453
% This input is considerably longer than accepted as input for
5554
% GPT-3.5 (16385 tokens)
5655
wayTooLong = string(repmat('a ',1,20000));
@@ -59,7 +58,7 @@ function doReturnErrors(testCase)
5958

6059
function seedFixesResult(testCase)
6160
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
62-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
61+
chat = azureChat;
6362
response1 = generate(chat,"hi",Seed=1234);
6463
response2 = generate(chat,"hi",Seed=1234);
6564
testCase.verifyEqual(response1,response2);
@@ -76,8 +75,7 @@ function createAzureChatWithStreamFunc(testCase)
7675
data = [data, str];
7776
seen = data;
7877
end
79-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ...
80-
StreamFun=@sf);
78+
chat = azureChat(StreamFun=@sf);
8179

8280
testCase.verifyWarningFree(@()generate(chat, "Hello world."));
8381
% Checking that persistent data, which is still stored in
@@ -93,8 +91,7 @@ function generateWithTools(testCase)
9391
f = addParameter(f, "location", type="string", description="The city and country, optionally state. E.g., San Francisco, CA, USA");
9492
f = addParameter(f, "unit", type="string", enum=["Kelvin","Celsius"], RequiredParameter=false);
9593

96-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ...
97-
Tools=f);
94+
chat = azureChat(Tools=f);
9895

9996
prompt = "What's the weather like in San Francisco, Tokyo, and Paris?";
10097
[~, response] = generate(chat, prompt, ToolChoice="getCurrentWeather");
@@ -108,12 +105,12 @@ function generateWithTools(testCase)
108105
end
109106

110107
function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
111-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), APIKey="this-is-not-a-real-key");
108+
chat = azureChat(APIKey="this-is-not-a-real-key");
112109
testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
113110
end
114111

115112
function shortErrorForBadEndpoint(testCase)
116-
chat = azureChat("https://nobodyhere.whatever/","deployment");
113+
chat = azureChat(Endpoint="https://nobodyhere.whatever/");
117114
caught = false;
118115
try
119116
generate(chat,"input");
@@ -126,17 +123,17 @@ function shortErrorForBadEndpoint(testCase)
126123
end
127124

128125
function invalidInputsConstructor(testCase, InvalidConstructorInput)
129-
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
126+
testCase.verifyError(@()azureChat(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
130127
end
131128

132129
function invalidInputsGenerate(testCase, InvalidGenerateInput)
133130
f = openAIFunction("validfunction");
134-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), Tools=f, APIKey="this-is-not-a-real-key");
131+
chat = azureChat(Tools=f, APIKey="this-is-not-a-real-key");
135132
testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error);
136133
end
137134

138135
function invalidSetters(testCase, InvalidValuesSetters)
139-
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), APIKey="this-is-not-a-real-key");
136+
chat = azureChat(APIKey="this-is-not-a-real-key");
140137
function assignValueToProperty(property, value)
141138
chat.(property) = value;
142139
end
@@ -151,7 +148,7 @@ function keyNotFound(testCase)
151148
import matlab.unittest.fixtures.EnvironmentVariableFixture
152149
testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_API_KEY","dummy"));
153150
unsetenv("AZURE_OPENAI_API_KEY");
154-
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT")), "llms:keyMustBeSpecified");
151+
testCase.verifyError(@()azureChat, "llms:keyMustBeSpecified");
155152
end
156153
end
157154
end

0 commit comments

Comments
 (0)