Skip to content

Commit cd506a2

Browse files
committed
add stream support for tools
1 parent 86cdab6 commit cd506a2

File tree

11 files changed

+103
-134
lines changed

11 files changed

+103
-134
lines changed

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import com.cjcrafter.openai.chat.tool.ToolChoice
55
import com.cjcrafter.openai.completions.CompletionRequest
66
import com.cjcrafter.openai.completions.CompletionResponse
77
import com.cjcrafter.openai.completions.CompletionResponseChunk
8-
import com.cjcrafter.openai.jackson.ChatChoiceChunkDeserializer
9-
import com.cjcrafter.openai.jackson.ChatChoiceChunkSerializer
108
import com.cjcrafter.openai.jackson.ToolChoiceDeserializer
119
import com.cjcrafter.openai.jackson.ToolChoiceSerializer
1210
import com.fasterxml.jackson.annotation.JsonInclude
@@ -145,8 +143,6 @@ interface OpenAI {
145143

146144
// Register modules with custom serializers/deserializers
147145
val module = SimpleModule().apply {
148-
addSerializer(ChatChoiceChunk::class.java, ChatChoiceChunkSerializer())
149-
addDeserializer(ChatChoiceChunk::class.java, ChatChoiceChunkDeserializer())
150146
addSerializer(ToolChoice::class.java, ToolChoiceSerializer())
151147
addDeserializer(ToolChoice::class.java, ToolChoiceDeserializer())
152148
}

src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,45 +9,11 @@ import com.fasterxml.jackson.databind.node.ObjectNode
99
import okhttp3.*
1010
import okhttp3.MediaType.Companion.toMediaType
1111
import okhttp3.RequestBody.Companion.toRequestBody
12+
import org.jetbrains.annotations.ApiStatus
1213
import java.io.BufferedReader
1314
import java.io.IOException
1415

15-
/**
16-
* The `OpenAI` class contains all the API calls to OpenAI's endpoint. Whether
17-
* you are working with images, chat, or completions, you need to have an
18-
* `OpenAI` instance to make the API requests.
19-
*
20-
* To get your API key:
21-
* 1. Log in to your account: Go to [https://www.openai.com/](openai.com) and
22-
* log in.
23-
* 2. Access the API dashboard: After logging in, click on the "API" tab.
24-
* 3. Choose a subscription plan: Select a suitable plan based on your needs
25-
* and complete the payment process.
26-
* 4. Obtain your API key: After subscribing to a plan, you will be redirected
27-
* to the API dashboard, where you can find your unique API key. Copy and store it securely.
28-
*
29-
* All API methods in this class have a non-blocking option which will enqueues
30-
* the HTTPS request on a different thread. These method names have `Async
31-
* appended to the end of their names.
32-
*
33-
* Completions API:
34-
* * [createCompletion]
35-
* * [streamCompletion]
36-
* * [createCompletionAsync]
37-
* * [streamCompletionAsync]
38-
*
39-
* Chat API:
40-
* * [createChatCompletion]
41-
* * [streamChatCompletion]
42-
* * [createChatCompletionAsync]
43-
* * [streamChatCompletionAsync]
44-
*
45-
* @property apiKey Your OpenAI API key. It starts with `"sk-"` (without the quotes).
46-
* @property organization If you belong to multiple organizations, specify which one to use (else `null`).
47-
* @property client Controls proxies, timeouts, etc.
48-
* @constructor Create a ChatBot for responding to requests.
49-
*/
50-
open class OpenAIImpl @JvmOverloads constructor(
16+
open class OpenAIImpl @ApiStatus.Internal constructor(
5117
protected val apiKey: String,
5218
protected val organization: String? = null,
5319
private val client: OkHttpClient = OkHttpClient()
@@ -98,7 +64,9 @@ open class OpenAIImpl @JvmOverloads constructor(
9864
}
9965

10066
val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
101-
return objectMapper.readValue(json, ChatResponse::class.java)
67+
val str = json.readText()
68+
println(str)
69+
return objectMapper.readValue(str, ChatResponse::class.java)
10270
}
10371

10472
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
@@ -143,10 +111,11 @@ open class OpenAIImpl @JvmOverloads constructor(
143111

144112
override fun next(): ChatResponseChunk {
145113
val currentLine = nextLine ?: throw NoSuchElementException("No more lines")
146-
//println(" $currentLine")
114+
//println(" $currentLine")
147115
chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode) } ?: objectMapper.readValue(currentLine, ChatResponseChunk::class.java)
148116
nextLine = readNextLine(reader) // Prepare the next line
149117
return chunk!!
118+
//return ChatResponseChunk("1", 1, listOf())
150119
}
151120
}
152121
}
Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package com.cjcrafter.openai.chat
22

33
import com.cjcrafter.openai.FinishReason
4-
import com.cjcrafter.openai.jackson.ChatChoiceChunkDeserializer
5-
import com.cjcrafter.openai.jackson.ChatChoiceChunkSerializer
4+
import com.cjcrafter.openai.OpenAI
5+
import com.cjcrafter.openai.chat.tool.ChatMessageDelta
6+
import com.cjcrafter.openai.chat.tool.ToolCall
67
import com.fasterxml.jackson.annotation.JsonProperty
78
import com.fasterxml.jackson.databind.JsonNode
8-
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
9+
import java.lang.IllegalArgumentException
910

1011
/**
1112
*
@@ -29,24 +30,38 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
2930
*/
3031
data class ChatChoiceChunk(
3132
val index: Int,
32-
val message: ChatMessage,
33-
var delta: String,
33+
var delta: ChatMessageDelta? = null,
3434
@JsonProperty("finish_reason") var finishReason: FinishReason?
3535
) {
36+
val message: ChatMessage = ChatMessage(delta?.role!!, delta?.content, delta?.toolCalls?.map { it.toToolCall() })
3637

37-
internal fun update(json: String) {
38-
val node: JsonNode = jacksonObjectMapper().readTree(json)
39-
val deltaNode = node.get("delta")
40-
delta = if (deltaNode?.has("content") == true && !deltaNode.get("content").isNull) {
41-
deltaNode.get("content").asText()
42-
} else {
43-
""
44-
}
38+
// Reads the json from the OpenAI API, and sets delta. message accumulates changes
39+
internal fun update(json: JsonNode) {
40+
val deltaNode = json.get("delta") ?: throw IllegalArgumentException("Passed a json without delta")
41+
val delta = OpenAI.createObjectMapper().treeToValue(deltaNode, ChatMessageDelta::class.java)
42+
43+
// The "bread and butter" of streaming. You can start showing the user
44+
// generated content usually *within a second*. However, for Tool Calls,
45+
// this will always be null.
46+
if (message.content == null)
47+
message.content = delta?.content // Always null for tool calls
48+
else if (delta.content != null)
49+
message.content += delta.content
4550

46-
message.content += delta
47-
finishReason = node.get("finish_reason")?.takeIf { !it.isNull }?.asText()?.let {
48-
FinishReason.valueOf(it.uppercase())
51+
// Handle updating the tool call
52+
if (message.toolCalls != null && delta.toolCalls != null) {
53+
for (deltaToolCall in delta.toolCalls) {
54+
val toolCall = message.toolCalls!![deltaToolCall.index]
55+
toolCall.update(deltaToolCall)
56+
}
4957
}
58+
59+
// The reason the bot stopped generating tokens. This is null until done.
60+
finishReason = json.get("finish_reason")?.let { if (it.isNull) null else FinishReason.valueOf(it.asText().uppercase()) }
61+
62+
// People can manually check changes instead of using this API's
63+
// accumulative changes.
64+
this.delta = delta
5065
}
5166

5267
/**
@@ -55,30 +70,4 @@ data class ChatChoiceChunk(
5570
* complete message.
5671
*/
5772
fun isFinished() = finishReason != null
58-
59-
companion object {
60-
@JvmStatic
61-
fun serializer() = ChatChoiceChunkSerializer()
62-
63-
@JvmStatic
64-
fun deserializer() = ChatChoiceChunkDeserializer()
65-
}
6673
}
67-
68-
/*
69-
Below is a potential Steam response from OpenAI. You can see that the first
70-
message contains 0 generated content, and the last message (before "[DONE]")
71-
adds the finish_reason.
72-
73-
data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
74-
75-
data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Hello"},"index":0,"finish_reason":null}]}
76-
77-
data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" World"},"index":0,"finish_reason":null}]}
78-
79-
data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"."},"index":0,"finish_reason":null}]}
80-
81-
data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
82-
83-
data: [DONE]
84-
*/

src/main/kotlin/com/cjcrafter/openai/chat/ChatResponseChunk.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ data class ChatResponseChunk(
3232
internal fun update(json: ObjectNode) {
3333
val choicesArray = json.get("choices") as? ArrayNode
3434
choicesArray?.forEachIndexed { index, jsonNode ->
35-
choices[index].update(jsonNode.toString())
35+
choices[index].update(jsonNode)
3636
}
3737
}
3838

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.cjcrafter.openai.chat.tool
2+
3+
import com.cjcrafter.openai.chat.ChatUser
4+
import com.fasterxml.jackson.annotation.JsonProperty
5+
6+
data class ChatMessageDelta(
7+
val role: ChatUser? = null,
8+
val content: String? = null,
9+
@JsonProperty("tool_calls") val toolCalls: List<ToolCallDelta>? = null,
10+
) {
11+
}

src/main/kotlin/com/cjcrafter/openai/chat/tool/FunctionCall.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.JsonProcessingException
55
import com.fasterxml.jackson.databind.JsonNode
66
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
77
import com.fasterxml.jackson.module.kotlin.readValue
8+
import org.jetbrains.annotations.ApiStatus
89

910
/**
1011
* Represents a function call by ChatGPT. When ChatGPT calls a function, you
@@ -23,6 +24,12 @@ data class FunctionCall(
2324
var arguments: String,
2425
) {
2526

27+
@ApiStatus.Internal
28+
fun update(delta: FunctionCallDelta) {
29+
// The only field that updates is arguments
30+
arguments += delta.arguments
31+
}
32+
2633
/**
2734
* Attempts to parse the arguments passed to the function.
2835
*
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.cjcrafter.openai.chat.tool
2+
3+
import org.jetbrains.annotations.ApiStatus
4+
5+
data class FunctionCallDelta(
6+
val name: String?,
7+
val arguments: String,
8+
) {
9+
10+
/**
11+
* Returns an **incomplete** function call.
12+
*/
13+
@ApiStatus.Internal
14+
fun toFunctionCall(): FunctionCall {
15+
return FunctionCall(
16+
name ?: throw IllegalStateException("name must be set"),
17+
arguments
18+
)
19+
}
20+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
package com.cjcrafter.openai.chat.tool
22

3+
import org.jetbrains.annotations.ApiStatus
4+
35
data class ToolCall(
46
var id: String,
57
var type: ToolType,
68
var function: FunctionCall,
79
) {
10+
@ApiStatus.Internal
11+
internal fun update(delta: ToolCallDelta) {
12+
// The only field that updates is function
13+
if (delta.function != null)
14+
function.update(delta.function)
15+
}
816
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.cjcrafter.openai.chat.tool
2+
3+
import org.jetbrains.annotations.ApiStatus
4+
5+
data class ToolCallDelta(
6+
val index: Int,
7+
val id: String? = null,
8+
val type: ToolType? = null,
9+
val function: FunctionCallDelta? = null,
10+
) {
11+
@ApiStatus.Internal
12+
fun toToolCall() = ToolCall(
13+
id = id ?: throw IllegalStateException("id must be set"),
14+
type = type ?: throw IllegalStateException("type must be set"),
15+
function = function?.toFunctionCall() ?: throw IllegalStateException("function must be set"),
16+
)
17+
}

src/main/kotlin/com/cjcrafter/openai/jackson/ChatChoiceChunkSerializers.kt

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/test/kotlin/KotlinTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ fun doChat(stream: Boolean) {
125125
do {
126126
if (stream) {
127127
for (chunk in openai.streamChatCompletion(request)) {
128-
print(chunk[0].delta)
128+
chunk[0].delta?.content?.let { print(it) }
129129
if (chunk[0].isFinished()) {
130130
finishReason = chunk[0].finishReason
131131
messages.add(chunk[0].message)

0 commit comments

Comments
 (0)