Skip to content

Commit 8db59dd

Browse files
authored
Merge pull request #13 from CJCrafter/completions
Completions
2 parents 7f38791 + 953eedf commit 8db59dd

19 files changed

+1240
-307
lines changed

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 339 additions & 112 deletions
Large diffs are not rendered by default.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package com.cjcrafter.openai
2+
3+
import com.cjcrafter.openai.exception.OpenAIError
4+
import com.cjcrafter.openai.exception.WrappedIOError
5+
import com.google.gson.JsonObject
6+
import com.google.gson.JsonParseException
7+
import com.google.gson.JsonParser
8+
import okhttp3.Call
9+
import okhttp3.Callback
10+
import okhttp3.Response
11+
import java.io.IOException
12+
import java.util.function.Consumer
13+
14+
internal class OpenAICallback(
15+
private val isStream: Boolean,
16+
private val onFailure: Consumer<OpenAIError>,
17+
private val onResponse: Consumer<JsonObject>
18+
) : Callback {
19+
20+
override fun onFailure(call: Call, e: IOException) {
21+
onFailure.accept(WrappedIOError(e))
22+
}
23+
24+
override fun onResponse(call: Call, response: Response) {
25+
onResponse(response)
26+
}
27+
28+
fun onResponse(response: Response) {
29+
if (isStream) {
30+
handleStream(response)
31+
return
32+
}
33+
34+
val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
35+
36+
// Sometimes OpenAI will respond with an error code for malformed
37+
// requests, timeouts, rate limits, etc. We need to let the dev
38+
// know that an error occurred.
39+
if (rootObject.has("error")) {
40+
onFailure.accept(OpenAIError.fromJson(rootObject.get("error").asJsonObject))
41+
return
42+
}
43+
44+
onResponse.accept(rootObject)
45+
}
46+
47+
private fun handleStream(response: Response) {
48+
response.body?.source()?.use { source ->
49+
50+
while (!source.exhausted()) {
51+
var jsonResponse = source.readUtf8Line()
52+
53+
// Or data is separated by empty lines, ignore them. The final
54+
// line is always "data: [DONE]", ignore it.
55+
if (jsonResponse.isNullOrEmpty() || jsonResponse == "data: [DONE]")
56+
continue
57+
58+
// The CHAT API returns a json string, but they prepend the content
59+
// with "data: " (which is not valid json). In order to parse this
60+
// into a JsonObject, we have to strip away this extra string.
61+
if (jsonResponse.startsWith("data: "))
62+
jsonResponse = jsonResponse.substring("data: ".length)
63+
64+
lateinit var rootObject: JsonObject
65+
try {
66+
rootObject = JsonParser.parseString(jsonResponse).asJsonObject
67+
} catch (ex: JsonParseException) {
68+
println(jsonResponse)
69+
ex.printStackTrace()
70+
continue
71+
}
72+
73+
// Sometimes OpenAI will respond with an error code for malformed
74+
// requests, timeouts, rate limits, etc. We need to let the dev
75+
// know that an error occurred.
76+
if (rootObject.has("error")) {
77+
onFailure.accept(OpenAIError.fromJson(rootObject.get("error").asJsonObject))
78+
continue
79+
}
80+
81+
// Developer defined code to run
82+
onResponse.accept(rootObject)
83+
}
84+
}
85+
}
86+
}

src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import com.google.gson.JsonObject
55
import com.google.gson.annotations.SerializedName
66

77
/**
8-
* The OpenAI API returns a list of [ChatChoice]. Each chat choice has a
8+
* The OpenAI API returns a list of `ChatChoice`. Each choice has a
99
* generated message ([ChatChoice.message]) and a finish reason
1010
* ([ChatChoice.finishReason]). For most use cases, you only need the generated
1111
* message.

src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ data class ChatChoiceChunk(
3737
message.content += delta
3838
finishReason = if (json["finish_reason"].isJsonNull) null else FinishReason.valueOf(json["finish_reason"].asString.uppercase())
3939
}
40+
41+
/**
42+
* Returns `true` if this message chunk is complete. Once complete, no more
43+
* tokens will be generated, and [ChatChoiceChunk.message] will contain the
44+
* complete message.
45+
*/
46+
fun isFinished() = finishReason != null
4047
}
4148

4249
/*

src/main/kotlin/com/cjcrafter/openai/chat/ChatRequest.kt

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import com.google.gson.annotations.SerializedName
66
* [ChatRequest] holds the configurable options that can be sent to the OpenAI
77
* Chat API. For most use cases, you only need to set [model] and [messages].
88
* For more detailed descriptions for each option, refer to the
9-
* [Chat Uncyclo](https://platform.openai.com/docs/api-reference/chat)
9+
* [Chat Uncyclo](https://platform.openai.com/docs/api-reference/chat).
1010
*
1111
* [messages] stores **ALL** previous messages from the conversation. It is
1212
* **YOUR RESPONSIBILITY** to store and update this list for your conversations
@@ -49,7 +49,7 @@ data class ChatRequest @JvmOverloads constructor(
4949
var temperature: Float? = null,
5050
@field:SerializedName("top_p") var topP: Float? = null,
5151
var n: Int? = null,
52-
@Deprecated("Use ChatBot#streamResponse") var stream: Boolean? = null,
52+
@Deprecated("Use OpenAI#streamChatCompletion") var stream: Boolean? = null,
5353
var stop: String? = null,
5454
@field:SerializedName("max_tokens") var maxTokens: Int? = null,
5555
@field:SerializedName("presence_penalty") var presencePenalty: Float? = null,
@@ -58,20 +58,8 @@ data class ChatRequest @JvmOverloads constructor(
5858
var user: String? = null
5959
) {
6060

61-
companion object {
62-
63-
/**
64-
* A static method that provides a new [Builder] instance for the
65-
* [ChatRequest] class.
66-
*
67-
* @return a new [Builder] instance for creating a [ChatRequest] object.
68-
*/
69-
@JvmStatic
70-
fun builder(): Builder = Builder()
71-
}
72-
7361
/**
74-
* [Builder] is a helper class to build a [ChatRequest] instance with a fluent API.
62+
* [Builder] is a helper class to build a [ChatRequest] instance with a stable API.
7563
* It provides methods for setting the properties of the [ChatRequest] object.
7664
* The [build] method returns a new [ChatRequest] instance with the specified properties.
7765
*
@@ -80,7 +68,6 @@ data class ChatRequest @JvmOverloads constructor(
8068
* val chatRequest = ChatRequest.builder()
8169
* .model("gpt-3.5-turbo")
8270
* .messages(mutableListOf("Be as helpful as possible".toSystemMessage()))
83-
* .temperature(0.7f)
8471
* .build()
8572
* ```
8673
*
@@ -222,4 +209,16 @@ data class ChatRequest @JvmOverloads constructor(
222209
)
223210
}
224211
}
212+
213+
companion object {
214+
215+
/**
216+
* A static method that provides a new [Builder] instance for the
217+
* [ChatRequest] class.
218+
*
219+
* @return a new [Builder] instance for creating a [ChatRequest] object.
220+
*/
221+
@JvmStatic
222+
fun builder(): Builder = Builder()
223+
}
225224
}

src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import java.time.ZonedDateTime
77
import java.util.*
88

99
/**
10-
* The [ChatResponse] contains all the data returned by the OpenAI Chat API.
10+
* The `ChatResponse` contains all the data returned by the OpenAI Chat API.
1111
* For most use cases, [ChatResponse.get] (passing 0 to the index argument) is
1212
* all you need.
1313
*
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.cjcrafter.openai.completions
2+
3+
import com.cjcrafter.openai.FinishReason
4+
import com.google.gson.annotations.SerializedName
5+
6+
/**
7+
* The OpenAI API returns a list of `CompletionChoice`. Each choice has a
8+
* generated message ([CompletionChoice.text]) and a finish reason
9+
* ([CompletionChoice.finishReason]). For most use cases, you only need the
10+
* generated text.
11+
*
12+
* By default, only 1 choice is generated (since [CompletionRequest.n] == 1).
13+
* When you increase `n` or provide a list of prompts (called batching),
14+
* there will be multiple choices.
15+
*
16+
* @property text The generated text.
17+
* @property index The index in the list... This is 0 for most use cases.
18+
* @property logprobs List of logarithmic probabilities for each token in the generated text.
19+
* @property finishReason The reason the bot stopped generating tokens.
20+
* @constructor Create empty Completion choice, for internal usage.
21+
* @see FinishReason
22+
*/
23+
data class CompletionChoice(
24+
val text: String,
25+
val index: Int,
26+
val logprobs: List<Float>?,
27+
@field:SerializedName("finish_reason") val finishReason: FinishReason
28+
)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.cjcrafter.openai.completions
2+
3+
import com.cjcrafter.openai.FinishReason
4+
import com.cjcrafter.openai.chat.ChatChoiceChunk
5+
import com.google.gson.annotations.SerializedName
6+
7+
/**
8+
* The OpenAI API returns a list of `CompletionChoice`. Each choice has a
9+
* generated message ([CompletionChoice.text]) and a finish reason
10+
* ([CompletionChoice.finishReason]). For most use cases, you only need the
11+
* generated text.
12+
*
13+
* By default, only 1 choice is generated (since [CompletionRequest.n] == 1).
14+
* When you increase `n` or provide a list of prompts (called batching),
15+
* there will be multiple choices.
16+
*
17+
* @property text The few generated tokens.
18+
* @property index The index in the list... This is 0 for most use cases.
19+
* @property logprobs List of logarithmic probabilities for each token in the generated text.
20+
* @property finishReason The reason the bot stopped generating tokens.
21+
* @constructor Create empty Completion choice, for internal usage.
22+
* @see FinishReason
23+
*/
24+
data class CompletionChoiceChunk(
25+
val text: String,
26+
val index: Int,
27+
val logprobs: List<Float>?,
28+
@field:SerializedName("finish_reason") val finishReason: FinishReason?
29+
) {
30+
/**
31+
* Returns `true` if this message chunk is complete. Once complete, no more
32+
* tokens will be generated.
33+
*/
34+
fun isFinished() = finishReason != null
35+
}

0 commit comments

Comments
 (0)