Skip to content

Add Azure OpenAI support #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ A community-maintained easy-to-use Java/Kotlin OpenAI API for ChatGPT, Text Comp
## Features
* [Completions](https://platform.openai.com/docs/api-reference/completions)
* [Chat Completions](https://platform.openai.com/docs/api-reference/chat)
* [Azure OpenAI](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) support via `AzureOpenAI` class

## Installation
For Kotlin DSL (`build.gradle.kts`), add this to your dependencies block:
Expand Down Expand Up @@ -85,6 +86,7 @@ public class JavaChatTest {
}
}
```
To use the Azure OpenAI API, use the `AzureOpenAI` class instead of `OpenAI`.
> **Note**: OpenAI recommends using environment variables for your API token
([Read more](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).

Expand Down
38 changes: 38 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/AzureOpenAI.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.cjcrafter.openai

import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.RequestBody.Companion.toRequestBody

/**
* The Azure OpenAI API client.
*
* See {@link OpenAI} for more information.
*
* This class constructs url in the form of: https://<azureBaseUrl>/openai/deployments/<modelName>/<endpoint>?api-version=<apiVersion>
*
* @property azureBaseUrl The base URL for the Azure OpenAI API. Usually https://<your_resource_group>.openai.azure.com
* @property apiVersion The API version to use. Defaults to 2023-03-15-preview.
* @property modelName The model name to use. This is the name of the model deployed to Azure.
*/
class AzureOpenAI @JvmOverloads constructor(
apiKey: String,
organization: String? = null,
client: OkHttpClient = OkHttpClient(),
private val azureBaseUrl: String = "",
private val apiVersion: String = "2023-03-15-preview",
private val modelName: String = ""
) : OpenAI(apiKey, organization, client) {

override fun buildRequest(request: Any, endpoint: String): Request {
val json = gson.toJson(request)
val body: RequestBody = json.toRequestBody(mediaType)
return Request.Builder()
.url("$azureBaseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
.addHeader("Content-Type", "application/json")
.addHeader("api-key", apiKey)
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
.post(body).build()
}
}
14 changes: 7 additions & 7 deletions src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ import java.util.function.Consumer
* @property client Controls proxies, timeouts, etc.
* @constructor Create a ChatBot for responding to requests.
*/
class OpenAI @JvmOverloads constructor(
private val apiKey: String,
private val organization: String? = null,
open class OpenAI @JvmOverloads constructor(
protected val apiKey: String,
protected val organization: String? = null,
private val client: OkHttpClient = OkHttpClient()
) {
private val mediaType = "application/json; charset=utf-8".toMediaType()
private val gson = createGson()
protected val mediaType = "application/json; charset=utf-8".toMediaType()
protected val gson = createGson()

private fun buildRequest(request: Any, endpoint: String): Request {
protected open fun buildRequest(request: Any, endpoint: String): Request {
val json = gson.toJson(request)
val body: RequestBody = json.toRequestBody(mediaType)
return Request.Builder()
Expand Down Expand Up @@ -95,7 +95,7 @@ class OpenAI @JvmOverloads constructor(
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)

try {
val httpResponse = client.newCall(httpRequest).execute();
val httpResponse = client.newCall(httpRequest).execute()
lateinit var response: CompletionResponse
OpenAICallback(true, { throw it }) {
response = gson.fromJson(it, CompletionResponse::class.java)
Expand Down
172 changes: 172 additions & 0 deletions src/test/java/JavaTestAzure.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import com.cjcrafter.openai.AzureOpenAI;
import com.cjcrafter.openai.OpenAI;
import com.cjcrafter.openai.chat.ChatMessage;
import com.cjcrafter.openai.chat.ChatRequest;
import com.cjcrafter.openai.chat.ChatResponse;
import com.cjcrafter.openai.completions.CompletionRequest;
import com.cjcrafter.openai.exception.OpenAIError;
import io.github.cdimascio.dotenv.Dotenv;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;


public class JavaTestAzure {

// Colors for pretty formatting
public static final String RESET = "\033[0m";
public static final String BLACK = "\033[0;30m";
public static final String RED = "\033[0;31m";
public static final String GREEN = "\033[0;32m";
public static final String YELLOW = "\033[0;33m";
public static final String BLUE = "\033[0;34m";
public static final String PURPLE = "\033[0;35m";
public static final String CYAN = "\033[0;36m";
public static final String WHITE = "\033[0;37m";

public static void main(String[] args) throws OpenAIError {
Scanner scanner = new Scanner(System.in);

// Add test cases for AzureOpenAI
System.out.println(GREEN + " 9. Azure Completion (create, sync)");
System.out.println(GREEN + " 10. Azure Completion (stream, sync)");
System.out.println(GREEN + " 11. Azure Completion (create, async)");
System.out.println(GREEN + " 12. Azure Completion (stream, async)");
System.out.println(GREEN + " 13. Azure Chat (create, sync)");
System.out.println(GREEN + " 14. Azure Chat (stream, sync)");
System.out.println(GREEN + " 15. Azure Chat (create, async)");
System.out.println(GREEN + " 16. Azure Chat (stream, async)");
System.out.println();

// Determine which method to call
switch (scanner.nextLine()) {
// ...
case "9":
doCompletionAzure(false, false);
break;
case "10":
doCompletionAzure(true, false);
break;
case "11":
doCompletionAzure(false, true);
break;
case "12":
doCompletionAzure(true, true);
break;
case "13":
doChatAzure(false, false);
break;
case "14":
doChatAzure(true, false);
break;
case "15":
doChatAzure(false, true);
break;
case "16":
doChatAzure(true, true);
break;
default:
System.err.println("Invalid option");
break;
}
}

public static void doCompletionAzure(boolean stream, boolean async) throws OpenAIError {
Scanner scan = new Scanner(System.in);
System.out.println(YELLOW + "Enter completion: ");
String input = scan.nextLine();

// CompletionRequest contains the data we sent to the OpenAI API. We use
// 128 tokens, so we have a bit of a delay before the response (for testing).
CompletionRequest request = CompletionRequest.builder()
.model("davinci")
.prompt(input)
.maxTokens(128).build();

// Loads the API key from the .env file in the root directory.
String key = Dotenv.load().get("OPENAI_TOKEN");
OpenAI openai = new AzureOpenAI(key);
System.out.println(RESET + "Generating Response" + PURPLE);

// Generate a print the message
if (stream) {
if (async)
openai.streamCompletionAsync(request, response -> System.out.print(response.get(0).getText()));
else
openai.streamCompletion(request, response -> System.out.print(response.get(0).getText()));
} else {
if (async)
openai.createCompletionAsync(request, response -> System.out.println(response.get(0).getText()));
else
System.out.println(openai.createCompletion(request).get(0).getText());
}

System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete." + RESET);
}

public static void doChatAzure(boolean stream, boolean async) throws OpenAIError {
Scanner scan = new Scanner(System.in);

// This is the prompt that the bot will refer back to for every message.
ChatMessage prompt = ChatMessage.toSystemMessage("You are a helpful chatbot.");

// Use a mutable (modifiable) list! Always! You should be reusing the
// ChatRequest variable, so in order for a conversation to continue
// you need to be able to modify the list.
List<ChatMessage> messages = new ArrayList<>(Collections.singletonList(prompt));

// ChatRequest is the request we send to OpenAI API. You can modify the
// model, temperature, maxTokens, etc. This should be saved, so you can
// reuse it for a conversation.
ChatRequest request = ChatRequest.builder()
.model("gpt-3.5-turbo")
.messages(messages).build();

// Loads the API key from the .env file in the root directory.
String key = Dotenv.load().get("OPENAI_TOKEN");
OpenAI openai = new AzureOpenAI(key);

// The conversation lasts until the user quits the program
while (true) {

// Prompt the user to enter a response
System.out.println(YELLOW + "Enter text below:\n\n");
String input = scan.nextLine();

// Add the newest user message to the conversation
messages.add(ChatMessage.toUserMessage(input));

System.out.println(RESET + "Generating Response" + PURPLE);
if (stream) {
if (async) {
openai.streamChatCompletionAsync(request, response -> {
System.out.print(response.get(0).getDelta());
if (response.get(0).isFinished())
messages.add(response.get(0).getMessage());
});
} else {
openai.streamChatCompletion(request, response -> {
System.out.print(response.get(0).getDelta());
if (response.get(0).isFinished())
messages.add(response.get(0).getMessage());
});
}
} else {
if (async) {
openai.createChatCompletionAsync(request, response -> {
System.out.println(response.get(0).getMessage().getContent());
messages.add(response.get(0).getMessage());
});
} else {
ChatResponse response = openai.createChatCompletion(request);
System.out.println(response.get(0).getMessage().getContent());
messages.add(response.get(0).getMessage());
}
}

System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete.");
}
}
}
Loading