Skip to content

Commit 305f371

Browse files
committed
add basic testing code for java and kotlin
1 parent c75b97f commit 305f371

File tree

9 files changed

+355
-218
lines changed

9 files changed

+355
-218
lines changed

src/main/kotlin/com/cjcrafter/openai/MyCallback.kt

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.cjcrafter.openai
33
import com.cjcrafter.openai.exception.OpenAIError
44
import com.cjcrafter.openai.exception.WrappedIOError
55
import com.google.gson.JsonObject
6+
import com.google.gson.JsonParseException
67
import com.google.gson.JsonParser
78
import okhttp3.Call
89
import okhttp3.Callback
@@ -45,20 +46,29 @@ internal class MyCallback(
4546

4647
private fun handleStream(response: Response) {
4748
response.body?.source()?.use { source ->
48-
while (!source.exhausted()) {
49-
var jsonResponse = source.readUtf8()
5049

51-
// OpenAI returns a json string, but they prepend the content with
52-
// "data: " (which is not valid json). In order to parse this into
53-
// a JsonObject, we have to strip away this extra string.
54-
jsonResponse = jsonResponse.substring("data: ".length)
50+
while (!source.exhausted()) {
51+
var jsonResponse = source.readUtf8Line()
5552

56-
// After OpenAI's final message (which already contains a non-null
57-
// finish reason), they redundantly send "data: [DONE]". Ignore it.
58-
if (jsonResponse == "[DONE]")
53+
// Or data is separated by empty lines, ignore them. The final
54+
// line is always "data: [DONE]", ignore it.
55+
if (jsonResponse.isNullOrEmpty() || jsonResponse == "data: [DONE]")
5956
continue
6057

61-
val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
58+
// The CHAT API returns a json string, but they prepend the content
59+
// with "data: " (which is not valid json). In order to parse this
60+
// into a JsonObject, we have to strip away this extra string.
61+
if (jsonResponse.startsWith("data: "))
62+
jsonResponse = jsonResponse.substring("data: ".length)
63+
64+
lateinit var rootObject: JsonObject
65+
try {
66+
rootObject = JsonParser.parseString(jsonResponse).asJsonObject
67+
} catch (ex: JsonParseException) {
68+
println(jsonResponse)
69+
ex.printStackTrace()
70+
continue
71+
}
6272

6373
// Sometimes OpenAI will respond with an error code for malformed
6474
// requests, timeouts, rate limits, etc. We need to let the dev

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,14 @@ class OpenAI @JvmOverloads constructor(
262262

263263
try {
264264
val httpResponse = client.newCall(httpRequest).execute()
265+
var response: ChatResponseChunk? = null
265266
MyCallback(true, onFailure) {
266-
val response = gson.fromJson(it, ChatResponseChunk::class.java)
267-
onResponse.accept(response)
267+
if (response == null)
268+
response = gson.fromJson(it, ChatResponseChunk::class.java)
269+
else
270+
response!!.update(it)
271+
272+
onResponse.accept(response!!)
268273
}.onResponse(httpResponse)
269274
} catch (ex: IOException) {
270275
onFailure.accept(WrappedIOError(ex))
@@ -291,17 +296,22 @@ class OpenAI @JvmOverloads constructor(
291296
*/
292297
@JvmOverloads
293298
fun streamChatCompletionAsync(
294-
request: CompletionRequest,
299+
request: ChatRequest,
295300
onResponse: Consumer<ChatResponseChunk>,
296301
onFailure: Consumer<OpenAIError> = Consumer { it.printStackTrace() }
297302
) {
298303
@Suppress("DEPRECATION")
299304
request.stream = true // use requestResponse for stream=false
300305
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
301306

307+
var response: ChatResponseChunk? = null
302308
client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) {
303-
val response = gson.fromJson(it, ChatResponseChunk::class.java)
304-
onResponse.accept(response)
309+
if (response == null)
310+
response = gson.fromJson(it, ChatResponseChunk::class.java)
311+
else
312+
response!!.update(it)
313+
314+
onResponse.accept(response!!)
305315
})
306316
}
307317

src/test/java/JavaChatStreamTest.java

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/test/java/JavaChatTest.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

src/test/java/JavaTest.java

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import com.cjcrafter.openai.OpenAI;
2+
import com.cjcrafter.openai.chat.ChatMessage;
3+
import com.cjcrafter.openai.chat.ChatRequest;
4+
import com.cjcrafter.openai.chat.ChatResponse;
5+
import com.cjcrafter.openai.completions.CompletionRequest;
6+
import com.cjcrafter.openai.exception.OpenAIError;
7+
import io.github.cdimascio.dotenv.Dotenv;
8+
9+
import java.util.ArrayList;
10+
import java.util.Collections;
11+
import java.util.List;
12+
import java.util.Scanner;
13+
14+
public class JavaTest {
15+
16+
// Colors for pretty formatting
17+
public static final String RESET = "\033[0m";
18+
public static final String BLACK = "\033[0;30m";
19+
public static final String RED = "\033[0;31m";
20+
public static final String GREEN = "\033[0;32m";
21+
public static final String YELLOW = "\033[0;33m";
22+
public static final String BLUE = "\033[0;34m";
23+
public static final String PURPLE = "\033[0;35m";
24+
public static final String CYAN = "\033[0;36m";
25+
public static final String WHITE = "\033[0;37m";
26+
27+
public static void main(String[] args) throws OpenAIError {
28+
Scanner scanner = new Scanner(System.in);
29+
30+
// Print out the menu of options
31+
System.out.println(GREEN + "Please select one of the options below by typing a number.");
32+
System.out.println();
33+
System.out.println(GREEN + " 1. Completion (create, sync)");
34+
System.out.println(GREEN + " 2. Completion (stream, sync)");
35+
System.out.println(GREEN + " 3. Completion (create, async)");
36+
System.out.println(GREEN + " 4. Completion (stream, async)");
37+
System.out.println(GREEN + " 5. Chat (create, sync)");
38+
System.out.println(GREEN + " 6. Chat (stream, sync)");
39+
System.out.println(GREEN + " 7. Chat (create, async)");
40+
System.out.println(GREEN + " 8. Chat (stream, async)");
41+
System.out.println();
42+
43+
// Determine which method to call
44+
switch (scanner.nextLine()) {
45+
case "1":
46+
doCompletion(false, false);
47+
break;
48+
case "2":
49+
doCompletion(true, false);
50+
break;
51+
case "3":
52+
doCompletion(false, true);
53+
break;
54+
case "4":
55+
doCompletion(true, true);
56+
break;
57+
case "5":
58+
doChat(false, false);
59+
break;
60+
case "6":
61+
doChat(true, false);
62+
break;
63+
case "7":
64+
doChat(false, true);
65+
break;
66+
case "8":
67+
doChat(true, true);
68+
break;
69+
default:
70+
System.err.println("Invalid option");
71+
break;
72+
}
73+
}
74+
75+
public static void doCompletion(boolean stream, boolean async) throws OpenAIError {
76+
Scanner scan = new Scanner(System.in);
77+
System.out.println(YELLOW + "Enter completion: ");
78+
String input = scan.nextLine();
79+
80+
// CompletionRequest contains the data we sent to the OpenAI API. We use
81+
// 128 tokens, so we have a bit of a delay before the response (for testing).
82+
CompletionRequest request = CompletionRequest.builder()
83+
.model("davinci")
84+
.prompt(input)
85+
.maxTokens(128).build();
86+
87+
// Loads the API key from the .env file in the root directory.
88+
String key = Dotenv.load().get("OPENAI_TOKEN");
89+
OpenAI openai = new OpenAI(key);
90+
91+
System.out.println(RESET + "Generating Response" + PURPLE);
92+
if (stream) {
93+
if (async)
94+
openai.streamCompletionAsync(request, response -> System.out.print(response.get(0).getText()));
95+
else
96+
openai.streamCompletion(request, response -> System.out.print(response.get(0).getText()));
97+
} else {
98+
if (async)
99+
openai.createCompletionAsync(request, response -> System.out.println(response.get(0).getText()));
100+
else
101+
System.out.println(openai.createCompletion(request).get(0).getText());
102+
}
103+
104+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete." + RESET);
105+
}
106+
107+
public static void doChat(boolean stream, boolean async) throws OpenAIError {
108+
Scanner scan = new Scanner(System.in);
109+
110+
// This is the prompt that the bot will refer back to for every message.
111+
ChatMessage prompt = ChatMessage.toSystemMessage("You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database.");
112+
113+
// Use a mutable (modifiable) list! Always! You should be reusing the
114+
// ChatRequest variable, so in order for a conversation to continue
115+
// you need to be able to modify the list.
116+
List<ChatMessage> messages = new ArrayList<>(Collections.singletonList(prompt));
117+
118+
// ChatRequest is the request we send to OpenAI API. You can modify the
119+
// model, temperature, maxTokens, etc. This should be saved, so you can
120+
// reuse it for a conversation.
121+
ChatRequest request = ChatRequest.builder()
122+
.model("gpt-3.5-turbo")
123+
.messages(messages).build();
124+
125+
// Loads the API key from the .env file in the root directory.
126+
String key = Dotenv.load().get("OPENAI_TOKEN");
127+
OpenAI openai = new OpenAI(key);
128+
129+
// The conversation lasts until the user quits the program
130+
while (true) {
131+
132+
// Prompt the user to enter a response
133+
System.out.println(YELLOW + "Enter text below:\n\n");
134+
String input = scan.nextLine();
135+
136+
// Add the newest user message to the conversation
137+
messages.add(ChatMessage.toUserMessage(input));
138+
139+
System.out.println(RESET + "Generating Response" + PURPLE);
140+
if (stream) {
141+
if (async) {
142+
openai.streamChatCompletionAsync(request, response -> {
143+
System.out.print(response.get(0).getDelta());
144+
if (response.get(0).isFinished())
145+
messages.add(response.get(0).getMessage());
146+
});
147+
} else {
148+
openai.streamChatCompletion(request, response -> {
149+
System.out.print(response.get(0).getDelta());
150+
if (response.get(0).isFinished())
151+
messages.add(response.get(0).getMessage());
152+
});
153+
}
154+
} else {
155+
if (async) {
156+
openai.createChatCompletionAsync(request, response -> {
157+
System.out.println(response.get(0).getMessage().getContent());
158+
messages.add(response.get(0).getMessage());
159+
});
160+
} else {
161+
ChatResponse response = openai.createChatCompletion(request);
162+
System.out.println(response.get(0).getMessage().getContent());
163+
messages.add(response.get(0).getMessage());
164+
}
165+
}
166+
167+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete.");
168+
}
169+
}
170+
}

0 commit comments

Comments
 (0)