Skip to content

Commit 261d703

Browse files
committed
2 parents 54483b5 + 68aeb73 commit 261d703

File tree

5 files changed

+357
-7
lines changed

5 files changed

+357
-7
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ A community-maintained easy-to-use Java/Kotlin OpenAI API for ChatGPT, Text Comp
1313
## Features
1414
* [Completions](https://platform.openai.com/docs/api-reference/completions)
1515
* [Chat Completions](https://platform.openai.com/docs/api-reference/chat)
16+
* [Azure OpenAI](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) support via `AzureOpenAI` class
1617

1718
## Installation
1819
For Kotlin DSL (`build.gradle.kts`), add this to your dependencies block:
@@ -85,6 +86,7 @@ public class JavaChatTest {
8586
}
8687
}
8788
```
89+
To use the Azure OpenAI API, use the `AzureOpenAI` class instead of `OpenAI`.
8890
> **Note**: OpenAI recommends using environment variables for your API token
8991
([Read more](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).
9092

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.cjcrafter.openai
2+
3+
import okhttp3.OkHttpClient
4+
import okhttp3.Request
5+
import okhttp3.RequestBody
6+
import okhttp3.RequestBody.Companion.toRequestBody
7+
8+
/**
9+
* The Azure OpenAI API client.
10+
*
11+
* See {@link OpenAI} for more information.
12+
*
13+
* This class constructs url in the form of: https://<azureBaseUrl>/openai/deployments/<modelName>/<endpoint>?api-version=<apiVersion>
14+
*
15+
* @property azureBaseUrl The base URL for the Azure OpenAI API. Usually https://<your_resource_group>.openai.azure.com
16+
* @property apiVersion The API version to use. Defaults to 2023-03-15-preview.
17+
* @property modelName The model name to use. This is the name of the model deployed to Azure.
18+
*/
19+
class AzureOpenAI @JvmOverloads constructor(
20+
apiKey: String,
21+
organization: String? = null,
22+
client: OkHttpClient = OkHttpClient(),
23+
private val azureBaseUrl: String = "",
24+
private val apiVersion: String = "2023-03-15-preview",
25+
private val modelName: String = ""
26+
) : OpenAI(apiKey, organization, client) {
27+
28+
override fun buildRequest(request: Any, endpoint: String): Request {
29+
val json = gson.toJson(request)
30+
val body: RequestBody = json.toRequestBody(mediaType)
31+
return Request.Builder()
32+
.url("$azureBaseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
33+
.addHeader("Content-Type", "application/json")
34+
.addHeader("api-key", apiKey)
35+
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
36+
.post(body).build()
37+
}
38+
}

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ import java.util.function.Consumer
5252
* @property client Controls proxies, timeouts, etc.
5353
* @constructor Create a ChatBot for responding to requests.
5454
*/
55-
class OpenAI @JvmOverloads constructor(
56-
private val apiKey: String,
57-
private val organization: String? = null,
55+
open class OpenAI @JvmOverloads constructor(
56+
protected val apiKey: String,
57+
protected val organization: String? = null,
5858
private val client: OkHttpClient = OkHttpClient()
5959
) {
60-
private val mediaType = "application/json; charset=utf-8".toMediaType()
61-
private val gson = createGson()
60+
protected val mediaType = "application/json; charset=utf-8".toMediaType()
61+
protected val gson = createGson()
6262

63-
private fun buildRequest(request: Any, endpoint: String): Request {
63+
protected open fun buildRequest(request: Any, endpoint: String): Request {
6464
val json = gson.toJson(request)
6565
val body: RequestBody = json.toRequestBody(mediaType)
6666
return Request.Builder()
@@ -95,7 +95,7 @@ class OpenAI @JvmOverloads constructor(
9595
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
9696

9797
try {
98-
val httpResponse = client.newCall(httpRequest).execute();
98+
val httpResponse = client.newCall(httpRequest).execute()
9999
lateinit var response: CompletionResponse
100100
OpenAICallback(true, { throw it }) {
101101
response = gson.fromJson(it, CompletionResponse::class.java)

src/test/java/JavaTestAzure.java

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import com.cjcrafter.openai.AzureOpenAI;
2+
import com.cjcrafter.openai.OpenAI;
3+
import com.cjcrafter.openai.chat.ChatMessage;
4+
import com.cjcrafter.openai.chat.ChatRequest;
5+
import com.cjcrafter.openai.chat.ChatResponse;
6+
import com.cjcrafter.openai.completions.CompletionRequest;
7+
import com.cjcrafter.openai.exception.OpenAIError;
8+
import io.github.cdimascio.dotenv.Dotenv;
9+
10+
import java.util.ArrayList;
11+
import java.util.Collections;
12+
import java.util.List;
13+
import java.util.Scanner;
14+
15+
16+
public class JavaTestAzure {
17+
18+
// Colors for pretty formatting
19+
public static final String RESET = "\033[0m";
20+
public static final String BLACK = "\033[0;30m";
21+
public static final String RED = "\033[0;31m";
22+
public static final String GREEN = "\033[0;32m";
23+
public static final String YELLOW = "\033[0;33m";
24+
public static final String BLUE = "\033[0;34m";
25+
public static final String PURPLE = "\033[0;35m";
26+
public static final String CYAN = "\033[0;36m";
27+
public static final String WHITE = "\033[0;37m";
28+
29+
public static void main(String[] args) throws OpenAIError {
30+
Scanner scanner = new Scanner(System.in);
31+
32+
// Add test cases for AzureOpenAI
33+
System.out.println(GREEN + " 9. Azure Completion (create, sync)");
34+
System.out.println(GREEN + " 10. Azure Completion (stream, sync)");
35+
System.out.println(GREEN + " 11. Azure Completion (create, async)");
36+
System.out.println(GREEN + " 12. Azure Completion (stream, async)");
37+
System.out.println(GREEN + " 13. Azure Chat (create, sync)");
38+
System.out.println(GREEN + " 14. Azure Chat (stream, sync)");
39+
System.out.println(GREEN + " 15. Azure Chat (create, async)");
40+
System.out.println(GREEN + " 16. Azure Chat (stream, async)");
41+
System.out.println();
42+
43+
// Determine which method to call
44+
switch (scanner.nextLine()) {
45+
// ...
46+
case "9":
47+
doCompletionAzure(false, false);
48+
break;
49+
case "10":
50+
doCompletionAzure(true, false);
51+
break;
52+
case "11":
53+
doCompletionAzure(false, true);
54+
break;
55+
case "12":
56+
doCompletionAzure(true, true);
57+
break;
58+
case "13":
59+
doChatAzure(false, false);
60+
break;
61+
case "14":
62+
doChatAzure(true, false);
63+
break;
64+
case "15":
65+
doChatAzure(false, true);
66+
break;
67+
case "16":
68+
doChatAzure(true, true);
69+
break;
70+
default:
71+
System.err.println("Invalid option");
72+
break;
73+
}
74+
}
75+
76+
public static void doCompletionAzure(boolean stream, boolean async) throws OpenAIError {
77+
Scanner scan = new Scanner(System.in);
78+
System.out.println(YELLOW + "Enter completion: ");
79+
String input = scan.nextLine();
80+
81+
// CompletionRequest contains the data we sent to the OpenAI API. We use
82+
// 128 tokens, so we have a bit of a delay before the response (for testing).
83+
CompletionRequest request = CompletionRequest.builder()
84+
.model("davinci")
85+
.prompt(input)
86+
.maxTokens(128).build();
87+
88+
// Loads the API key from the .env file in the root directory.
89+
String key = Dotenv.load().get("OPENAI_TOKEN");
90+
OpenAI openai = new AzureOpenAI(key);
91+
System.out.println(RESET + "Generating Response" + PURPLE);
92+
93+
// Generate a print the message
94+
if (stream) {
95+
if (async)
96+
openai.streamCompletionAsync(request, response -> System.out.print(response.get(0).getText()));
97+
else
98+
openai.streamCompletion(request, response -> System.out.print(response.get(0).getText()));
99+
} else {
100+
if (async)
101+
openai.createCompletionAsync(request, response -> System.out.println(response.get(0).getText()));
102+
else
103+
System.out.println(openai.createCompletion(request).get(0).getText());
104+
}
105+
106+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete." + RESET);
107+
}
108+
109+
public static void doChatAzure(boolean stream, boolean async) throws OpenAIError {
110+
Scanner scan = new Scanner(System.in);
111+
112+
// This is the prompt that the bot will refer back to for every message.
113+
ChatMessage prompt = ChatMessage.toSystemMessage("You are a helpful chatbot.");
114+
115+
// Use a mutable (modifiable) list! Always! You should be reusing the
116+
// ChatRequest variable, so in order for a conversation to continue
117+
// you need to be able to modify the list.
118+
List<ChatMessage> messages = new ArrayList<>(Collections.singletonList(prompt));
119+
120+
// ChatRequest is the request we send to OpenAI API. You can modify the
121+
// model, temperature, maxTokens, etc. This should be saved, so you can
122+
// reuse it for a conversation.
123+
ChatRequest request = ChatRequest.builder()
124+
.model("gpt-3.5-turbo")
125+
.messages(messages).build();
126+
127+
// Loads the API key from the .env file in the root directory.
128+
String key = Dotenv.load().get("OPENAI_TOKEN");
129+
OpenAI openai = new AzureOpenAI(key);
130+
131+
// The conversation lasts until the user quits the program
132+
while (true) {
133+
134+
// Prompt the user to enter a response
135+
System.out.println(YELLOW + "Enter text below:\n\n");
136+
String input = scan.nextLine();
137+
138+
// Add the newest user message to the conversation
139+
messages.add(ChatMessage.toUserMessage(input));
140+
141+
System.out.println(RESET + "Generating Response" + PURPLE);
142+
if (stream) {
143+
if (async) {
144+
openai.streamChatCompletionAsync(request, response -> {
145+
System.out.print(response.get(0).getDelta());
146+
if (response.get(0).isFinished())
147+
messages.add(response.get(0).getMessage());
148+
});
149+
} else {
150+
openai.streamChatCompletion(request, response -> {
151+
System.out.print(response.get(0).getDelta());
152+
if (response.get(0).isFinished())
153+
messages.add(response.get(0).getMessage());
154+
});
155+
}
156+
} else {
157+
if (async) {
158+
openai.createChatCompletionAsync(request, response -> {
159+
System.out.println(response.get(0).getMessage().getContent());
160+
messages.add(response.get(0).getMessage());
161+
});
162+
} else {
163+
ChatResponse response = openai.createChatCompletion(request);
164+
System.out.println(response.get(0).getMessage().getContent());
165+
messages.add(response.get(0).getMessage());
166+
}
167+
}
168+
169+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete.");
170+
}
171+
}
172+
}

0 commit comments

Comments
 (0)