Skip to content

Commit 15fbf58

Browse files
committed
add completion streaming support
1 parent 0017d93 commit 15fbf58

File tree

4 files changed

+203
-12
lines changed

4 files changed

+203
-12
lines changed

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter
44
import com.cjcrafter.openai.chat.*
55
import com.cjcrafter.openai.completions.CompletionRequest
66
import com.cjcrafter.openai.completions.CompletionResponse
7+
import com.cjcrafter.openai.completions.CompletionResponseChunk
78
import com.cjcrafter.openai.exception.OpenAIError
89
import com.cjcrafter.openai.exception.WrappedIOError
910
import com.cjcrafter.openai.gson.ChatUserAdapter
@@ -16,6 +17,7 @@ import okhttp3.*
1617
import okhttp3.MediaType.Companion.toMediaType
1718
import okhttp3.RequestBody.Companion.toRequestBody
1819
import java.io.IOException
20+
import java.lang.IllegalStateException
1921
import java.util.function.Consumer
2022

2123
/**
@@ -56,22 +58,27 @@ class OpenAI @JvmOverloads constructor(
5658
.post(body).build()
5759
}
5860

61+
/**
62+
* Create completion
63+
*
64+
* @param request
65+
* @return
66+
* @since 1.3.0
67+
*/
5968
@Throws(OpenAIError::class)
6069
fun createCompletion(request: CompletionRequest): CompletionResponse {
6170
@Suppress("DEPRECATION")
6271
request.stream = false // use streamCompletion for stream=true
6372
val httpRequest = buildRequest(request, "completions")
6473

65-
// Save the JsonObject to check for errors
66-
var rootObject: JsonObject?
6774
try {
6875
client.newCall(httpRequest).execute().use { response ->
6976

7077
// Servers respond to API calls with json blocks. Since raw JSON isn't
7178
// very developer friendly, we wrap for easy data access.
72-
rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
73-
if (rootObject!!.has("error"))
74-
throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject)
79+
val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
80+
if (rootObject.has("error"))
81+
throw OpenAIError.fromJson(rootObject.get("error").asJsonObject)
7582

7683
return gson.fromJson(rootObject, CompletionResponse::class.java)
7784
}
@@ -81,6 +88,78 @@ class OpenAI @JvmOverloads constructor(
8188
}
8289
}
8390

91+
/**
 * Kotlin-friendly entry point for [streamCompletion]. The callback is a
 * lambda with a [CompletionResponseChunk] receiver, so the chunk's members
 * can be accessed directly inside the lambda body.
 *
 * The failure handler is not configurable here; the default handler of
 * [streamCompletion] is used.
 *
 * @param request The completion request to send.
 * @param onResponse The lambda to invoke for each streamed chunk.
 * @since 1.3.0
 */
fun streamCompletionKotlin(request: CompletionRequest, onResponse: CompletionResponseChunk.() -> Unit) {
    // Adapt the receiver lambda to the Consumer expected by streamCompletion.
    val chunkConsumer = Consumer<CompletionResponseChunk> { chunk -> chunk.onResponse() }
    streamCompletion(request, chunkConsumer)
}
101+
102+
/**
 * Sends a streaming completion request. The call is enqueued on the HTTP
 * client and this method returns without waiting for the response.
 * [onResponse] is invoked from the client's callback rather than by the
 * caller, so consider thread safety within the context of your program.
 *
 * The given [request] is modified: its `stream` flag is set to `true`.
 *
 * The response body is read line by line:
 * - empty lines are skipped;
 * - lines that do not start with `data: ` are printed to [System.err] and skipped;
 * - the `[DONE]` marker is skipped;
 * - a JSON payload containing an `error` member is thrown as an [OpenAIError]
 *   from inside the callback (it is NOT passed to [onFailure]);
 * - every other payload is parsed into a [CompletionResponseChunk] and
 *   handed to [onResponse].
 *
 * @param request The completion request to send.
 * @param onResponse The method to call for each chunk.
 * @param onFailure The method to call if the HTTP call fails. This method
 *                  will not be called if OpenAI returns an error.
 * @see createCompletion
 * @see streamCompletionKotlin
 * @since 1.3.0
 */
@JvmOverloads
fun streamCompletion(
    request: CompletionRequest,
    onResponse: Consumer<CompletionResponseChunk>, // Consumer (not a Kotlin lambda) for nicer Java call sites
    onFailure: Consumer<OpenAIError> = Consumer { it.printStackTrace() }
) {
    @Suppress("DEPRECATION")
    request.stream = true // use createCompletion for stream=false
    val httpRequest = buildRequest(request, "completions")

    // Every streamed event line is expected to carry this prefix.
    val dataPrefix = "data: "

    val streamCallback = object : Callback {

        override fun onFailure(call: Call, e: IOException) {
            onFailure.accept(WrappedIOError(e))
        }

        override fun onResponse(call: Call, response: Response) {
            response.body?.source()?.use { reader ->
                while (!reader.exhausted()) {

                    // Nothing to do for missing or blank lines.
                    val line = reader.readUtf8Line()?.takeIf { it.isNotEmpty() } ?: continue

                    // Anything that is not an event line is reported on
                    // stderr and otherwise ignored.
                    if (!line.startsWith(dataPrefix)) {
                        System.err.println(line)
                        continue
                    }

                    // Strip the prefix; the end-of-stream marker carries no JSON.
                    val payload = line.substring(dataPrefix.length)
                    if (payload == "[DONE]")
                        continue

                    val json = JsonParser.parseString(payload).asJsonObject
                    if (json.has("error"))
                        throw OpenAIError.fromJson(json.get("error").asJsonObject)

                    val chunk = gson.fromJson(json, CompletionResponseChunk::class.java)
                    onResponse.accept(chunk)
                }
            }
        }
    }

    client.newCall(httpRequest).enqueue(streamCallback)
}
162+
84163
/**
85164
* Blocks the current thread until OpenAI responds to https request. The
86165
* returned value includes information including tokens, generated text,
@@ -97,16 +176,14 @@ class OpenAI @JvmOverloads constructor(
97176
request.stream = false // use streamResponse for stream=true
98177
val httpRequest = buildRequest(request, "chat/completions")
99178

100-
// Save the JsonObject to check for errors
101-
var rootObject: JsonObject?
102179
try {
103180
client.newCall(httpRequest).execute().use { response ->
104181

105182
// Servers respond to API calls with json blocks. Since raw JSON isn't
106183
// very developer friendly, we wrap for easy data access.
107-
rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
108-
if (rootObject!!.has("error"))
109-
throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject)
184+
val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
185+
if (rootObject.has("error"))
186+
throw OpenAIError.fromJson(rootObject.get("error").asJsonObject)
110187

111188
return gson.fromJson(rootObject, ChatResponse::class.java)
112189
}
@@ -176,7 +253,7 @@ class OpenAI @JvmOverloads constructor(
176253
fun streamChatCompletion(
177254
request: ChatRequest,
178255
onResponse: Consumer<ChatResponseChunk>, // use Consumer instead of Kotlin for better Java syntax
179-
onFailure: Consumer<IOException> = Consumer { it.printStackTrace() }
256+
onFailure: Consumer<WrappedIOError> = Consumer { it.printStackTrace() }
180257
) {
181258
@Suppress("DEPRECATION")
182259
request.stream = true // use requestResponse for stream=false
@@ -186,7 +263,7 @@ class OpenAI @JvmOverloads constructor(
186263
var cache: ChatResponseChunk? = null
187264

188265
override fun onFailure(call: Call, e: IOException) {
189-
onFailure.accept(e)
266+
onFailure.accept(WrappedIOError(e))
190267
}
191268

192269
override fun onResponse(call: Call, response: Response) {
@@ -203,6 +280,9 @@ class OpenAI @JvmOverloads constructor(
203280
continue
204281

205282
val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
283+
if (rootObject.has("error"))
284+
throw OpenAIError.fromJson(rootObject.get("error").asJsonObject)
285+
206286
if (cache == null)
207287
cache = gson.fromJson(rootObject, ChatResponseChunk::class.java)
208288
else
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.cjcrafter.openai.completions
2+
3+
import com.cjcrafter.openai.FinishReason
4+
import com.google.gson.annotations.SerializedName
5+
6+
/**
 * A single choice inside a streamed [CompletionResponseChunk]. Each choice
 * holds a piece of generated text ([text]) and an optional finish reason
 * ([finishReason]). For most use cases, you only need the generated text.
 *
 * By default, only 1 choice is generated (since [CompletionRequest.n] == 1).
 * When you increase `n` or provide a list of prompts (called batching),
 * there will be multiple choices.
 *
 * @property text The few generated tokens.
 * @property index The index in the list of choices. This is 0 for most use cases.
 * @property logprobs Logarithmic probabilities for the tokens in the generated text; may be null.
 * @property finishReason The reason the bot stopped generating tokens; may be null.
 * @constructor For internal usage.
 * @see FinishReason
 */
data class CompletionChoiceChunk(
    val text: String,
    val index: Int,
    val logprobs: List<Float>?,
    @field:SerializedName("finish_reason")
    val finishReason: FinishReason?,
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.cjcrafter.openai.completions
2+
3+
import java.time.Instant
4+
import java.time.ZoneId
5+
import java.time.ZonedDateTime
6+
import java.util.*
7+
8+
/**
 * A `CompletionResponseChunk` holds the data of one chunk of a streamed
 * response from the OpenAI Completions API. For most use cases,
 * [CompletionResponseChunk.get] (passing 0 to the index argument) is all
 * you need.
 *
 * @property id The unique id for your request.
 * @property created The Unix timestamp (measured in seconds since 00:00:00 UTC on January 1, 1970) when the API response was created.
 * @property model The model used to generate the completion.
 * @property choices The generated completion choice(s) in this chunk.
 * @constructor Create a completion response chunk (for internal usage)
 */
data class CompletionResponseChunk(
    val id: String,
    val created: Long,
    val model: String,
    val choices: List<CompletionChoiceChunk>,
) {

    /**
     * Converts [created] (seconds since 00:00:00 UTC on January 1, 1970)
     * into an [Instant].
     *
     * Note that users expect time to be measured in their timezone, so
     * [getZonedTime] is preferred.
     *
     * @return The instant the api created this response.
     * @see getZonedTime
     */
    fun getTime(): Instant = Instant.ofEpochSecond(created)

    /**
     * Returns the creation time of this response in the given timezone.
     * When no timezone is passed, the system's default timezone is used.
     *
     * @param timezone The user's timezone.
     * @return The timezone adjusted date time.
     * @see TimeZone.getDefault
     */
    @JvmOverloads
    fun getZonedTime(timezone: ZoneId = TimeZone.getDefault().toZoneId()): ZonedDateTime =
        getTime().atZone(timezone)

    /**
     * Index operator shorthand for [CompletionResponseChunk.choices].
     *
     * @param index The index of the choice.
     * @return The generated [CompletionChoiceChunk] at the index.
     */
    operator fun get(index: Int): CompletionChoiceChunk = choices[index]
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import com.cjcrafter.openai.OpenAI
2+
import com.cjcrafter.openai.completions.CompletionRequest
3+
import io.github.cdimascio.dotenv.dotenv
4+
5+
/**
 * Example program: streams a completion for a fixed prompt and prints the
 * text of the first choice of every chunk as it arrives.
 */
fun main(args: Array<String>) {

    // Build the completion request that will be streamed.
    val completionRequest = CompletionRequest(
        model = "davinci",
        prompt = "Hello darkness",
        maxTokens = 1024
    )

    // The API key is read from the .env file in the root directory.
    val apiKey = dotenv()["OPENAI_TOKEN"]
    val api = OpenAI(apiKey)

    // Stream the response and print each piece of text as it arrives.
    // Non-streaming alternative: println(api.createCompletion(completionRequest))
    api.streamCompletionKotlin(completionRequest) {
        print(choices[0].text)
    }
}

0 commit comments

Comments
 (0)