Skip to content

Commit cf65c34

Browse files
committed
implement tools
1 parent 0adbbc7 commit cf65c34

27 files changed

+756
-335
lines changed

build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ repositories {
2020
dependencies {
2121
implementation("com.squareup.okhttp3:okhttp:4.9.2")
2222
implementation("com.google.code.gson:gson:2.10.1")
23+
implementation("org.jetbrains:annotations:24.0.1")
2324

2425
testImplementation("io.github.cdimascio:dotenv-kotlin:6.4.1")
2526
testImplementation("org.junit.jupiter:junit-jupiter:5.9.2")
Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.cjcrafter.openai
22

3-
import com.cjcrafter.openai.gson.FinishReasonAdapter
3+
import com.google.gson.annotations.SerializedName
44

55
/**
66
* [FinishReason] wraps the possible reasons that a generation model may stop
@@ -15,6 +15,7 @@ enum class FinishReason {
1515
* completely generates its entire message, and has nothing else to add.
1616
* Ideally, you always want your finish reason to be [STOP].
1717
*/
18+
@SerializedName("stop")
1819
STOP,
1920

2021
/**
@@ -23,33 +24,23 @@ enum class FinishReason {
2324
* message with finish reason [LENGTH]. Some models have a higher token
2425
* limit than others.
2526
*/
27+
@SerializedName("length")
2628
LENGTH,
2729

2830
/**
2931
* Occurs due to a flag from OpenAI's content filters. This occurrence is
3032
* rare, and tends to happen when you blatantly violate OpenAI's terms.
3133
*/
34+
@SerializedName("content_filter")
3235
CONTENT_FILTER,
3336

3437
/**
3538
* Occurs when the model uses one of the available tools.
3639
*/
40+
@SerializedName("tool_calls")
3741
TOOL_CALLS,
3842

3943
@Deprecated("functions have been replaced by tools")
44+
@SerializedName("function_call")
4045
FUNCTION_CALL;
41-
42-
companion object {
43-
44-
/**
45-
* Returns the google gson adapter for serializing this enum to a json
46-
* file. Whe
47-
*
48-
* @return
49-
*/
50-
@JvmStatic
51-
fun adapter() : FinishReasonAdapter {
52-
return FinishReasonAdapter()
53-
}
54-
}
5546
}

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 95 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,163 @@
11
package com.cjcrafter.openai
22

33
import com.cjcrafter.openai.chat.*
4+
import com.cjcrafter.openai.chat.tool.ToolChoice
45
import com.cjcrafter.openai.completions.CompletionRequest
56
import com.cjcrafter.openai.completions.CompletionResponse
67
import com.cjcrafter.openai.completions.CompletionResponseChunk
7-
import com.cjcrafter.openai.exception.OpenAIError
8-
import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter
9-
import com.cjcrafter.openai.gson.ChatUserAdapter
10-
import com.cjcrafter.openai.gson.FinishReasonAdapter
118
import com.google.gson.Gson
129
import com.google.gson.GsonBuilder
1310
import okhttp3.OkHttpClient
1411

1512
interface OpenAI {
1613

17-
@Throws(OpenAIError::class)
14+
/**
15+
* Calls the [completions](https://platform.openai.com/docs/api-reference/completions)
16+
* API endpoint. This method is blocking.
17+
*
18+
* Completions are considered Legacy, and OpenAI officially recommends that
19+
* all developers use the **chat completion** endpoint instead. See
20+
* [createChatCompletion].
21+
*
22+
* @param request The request to send to the API
23+
* @return The response from the API
24+
*/
1825
fun createCompletion(request: CompletionRequest): CompletionResponse
1926

20-
@Throws(OpenAIError::class)
27+
/**
28+
* Calls the [completions](https://platform.openai.com/docs/api-reference/completions)
29+
* API endpoint and streams each token 1 at a time for a faster response
30+
* time.
31+
*
32+
* This method is **technically** not blocking, but the returned iterable
33+
* will block until the next token is generated.
34+
* ```
35+
* // Each iteration of the loop will block until the next token is streamed
36+
* for (chunk in openAI.streamCompletion(request)) {
37+
* // Do something with the chunk
38+
* }
39+
* ```
40+
*
41+
* Completions are considered Legacy, and OpenAI officially recommends that
42+
* all developers use the **chat completion** endpoint isntead. See
43+
* [streamChatCompletion].
44+
*
45+
* @param request The request to send to the API
46+
* @return The response from the API
47+
*/
2148
fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk>
2249

23-
@Throws(OpenAIError::class)
50+
/**
51+
* Calls the [chat completions](https://platform.openai.com/docs/api-reference/chat)
52+
* API endpoint. This method is blocking.
53+
*
54+
* @param request The request to send to the API
55+
* @return The response from the API
56+
*/
2457
fun createChatCompletion(request: ChatRequest): ChatResponse
2558

26-
@Throws(OpenAIError::class)
59+
/**
60+
* Calls the [chat completions](https://platform.openai.com/docs/api-reference/chat)
61+
* API endpoint and streams each token 1 at a time for a faster response.
62+
*
63+
* This method is **technically** not blocking, but the returned iterable
64+
* will block until the next token is generated.
65+
* ```
66+
* // Each iteration of the loop will block until the next token is streamed
67+
* for (chunk in openAI.streamChatCompletion(request)) {
68+
* // Do something with the chunk
69+
* }
70+
* ```
71+
*
72+
* @param request The request to send to the API
73+
* @return The response from the API
74+
*/
2775
fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk>
2876

29-
open class Builder {
77+
open class Builder internal constructor() {
3078
protected var apiKey: String? = null
3179
protected var organization: String? = null
3280
protected var client: OkHttpClient = OkHttpClient()
3381

3482
fun apiKey(apiKey: String) = apply { this.apiKey = apiKey }
35-
3683
fun organization(organization: String?) = apply { this.organization = organization }
37-
3884
fun client(client: OkHttpClient) = apply { this.client = client }
3985

4086
open fun build(): OpenAI {
41-
checkNotNull(apiKey) { "apiKey must be defined to use OpenAI" }
42-
return OpenAIImpl(apiKey!!, organization, client)
87+
return OpenAIImpl(
88+
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
89+
organization,
90+
client
91+
)
4392
}
4493
}
4594

46-
class AzureBuilder : Builder() {
95+
class AzureBuilder internal constructor(): Builder() {
4796
private var azureBaseUrl: String? = null
4897
private var apiVersion: String? = null
4998
private var modelName: String? = null
5099

51100
fun azureBaseUrl(azureBaseUrl: String) = apply { this.azureBaseUrl = azureBaseUrl }
52-
53101
fun apiVersion(apiVersion: String) = apply { this.apiVersion = apiVersion }
54-
55102
fun modelName(modelName: String) = apply { this.modelName = modelName }
56103

57104
override fun build(): OpenAI {
58-
checkNotNull(apiKey) { "apiKey must be defined to use OpenAI" }
59-
checkNotNull(azureBaseUrl) { "azureBaseUrl must be defined for azure" }
60-
checkNotNull(apiVersion) { "apiVersion must be defined for azure" }
61-
checkNotNull(modelName) { "modelName must be defined for azure" }
62-
63-
return AzureOpenAI(apiKey!!, organization, client, azureBaseUrl!!, apiVersion!!, modelName!!)
105+
return AzureOpenAI(
106+
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
107+
organization,
108+
client,
109+
azureBaseUrl ?: throw IllegalStateException("azureBaseUrl must be defined for azure"),
110+
apiVersion ?: throw IllegalStateException("apiVersion must be defined for azure"),
111+
modelName ?: throw IllegalStateException("modelName must be defined for azure")
112+
)
64113
}
65114
}
66115

67116
companion object {
68117

118+
/**
119+
* Instantiates a builder for a default OpenAI instance. For Azure's
120+
* OpenAI, use [azureBuilder] instead.
121+
*/
69122
@JvmStatic
70123
fun builder() = Builder()
71124

125+
/**
126+
* Instantiates a builder for an Azure OpenAI.
127+
*/
72128
@JvmStatic
73129
fun azureBuilder() = AzureBuilder()
74130

131+
/**
132+
* Returns a Gson instance with the default OpenAI adapters registered.
133+
* This can be used to save conversations (and other data) to file.
134+
*/
75135
@JvmStatic
76136
fun createGson(): Gson = createGsonBuilder().create()
77137

138+
/**
139+
* Returns a GsonBuilder instance with the default OpenAI adapters
140+
* registered.
141+
*/
78142
@JvmStatic
79143
fun createGsonBuilder(): GsonBuilder {
80144
return GsonBuilder()
81-
.registerTypeAdapter(ChatUser::class.java, ChatUserAdapter())
82-
.registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter())
83-
.registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter())
145+
.serializeNulls()
146+
.registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunk.adapter())
147+
.registerTypeAdapter(ToolChoice::class.java, ToolChoice.adapter())
84148
}
85149

150+
/**
151+
* Extension function to stream a completion using kotlin coroutines.
152+
*/
86153
fun OpenAI.streamCompletion(request: CompletionRequest, consumer: (CompletionResponseChunk) -> Unit) {
87154
for (chunk in streamCompletion(request))
88155
consumer(chunk)
89156
}
90157

158+
/**
159+
* Extension function to stream a chat completion using kotlin coroutines.
160+
*/
91161
fun OpenAI.streamChatCompletion(request: ChatRequest, consumer: (ChatResponseChunk) -> Unit) {
92162
for (chunk in streamChatCompletion(request))
93163
consumer(chunk)

src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,15 @@ open class OpenAIImpl @JvmOverloads constructor(
9696
request.stream = false // use streamChatCompletion for stream=true
9797
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
9898

99-
return ChatResponse("1", 1, listOf(), ChatUsage(1, 1, 1))
99+
val httpResponse = client.newCall(httpRequest).execute()
100+
if (!httpResponse.isSuccessful) {
101+
val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
102+
httpResponse.close()
103+
throw IOException("Unexpected code $httpResponse, recieved: $json")
104+
}
105+
106+
val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
107+
return gson.fromJson(json, ChatResponse::class.java)
100108
}
101109

102110
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {

src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.cjcrafter.openai.chat
22

33
import com.cjcrafter.openai.FinishReason
4+
import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter
45
import com.google.gson.JsonObject
56
import com.google.gson.annotations.SerializedName
67

@@ -44,6 +45,11 @@ data class ChatChoiceChunk(
4445
* complete message.
4546
*/
4647
fun isFinished() = finishReason != null
48+
49+
companion object {
50+
@JvmStatic
51+
fun adapter() = ChatChoiceChunkAdapter()
52+
}
4753
}
4854

4955
/*

src/main/kotlin/com/cjcrafter/openai/chat/ChatMessage.kt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package com.cjcrafter.openai.chat
22

3+
import com.cjcrafter.openai.chat.tool.ToolCall
4+
import com.google.gson.annotations.SerializedName
5+
36
/**
47
* ChatGPT's biggest innovation is its conversation memory. To remember the
58
* conversation, we need to map each message to who sent it. This data class
@@ -9,7 +12,17 @@ package com.cjcrafter.openai.chat
912
* @property content The string content of the message.
1013
* @see ChatUser
1114
*/
12-
data class ChatMessage(var role: ChatUser, var content: String) {
15+
data class ChatMessage @JvmOverloads constructor(
16+
var role: ChatUser,
17+
var content: String?,
18+
@field:SerializedName("tool_calls") var toolCalls: List<ToolCall>? = null,
19+
@field:SerializedName("tool_call_id") var toolCallId: String? = null,
20+
) {
21+
init {
22+
if (role == ChatUser.TOOL) {
23+
requireNotNull(toolCallId) { "toolCallId must be set when role is TOOL" }
24+
}
25+
}
1326

1427
companion object {
1528

@@ -36,5 +49,13 @@ data class ChatMessage(var role: ChatUser, var content: String) {
3649
fun String.toAssistantMessage(): ChatMessage {
3750
return ChatMessage(ChatUser.ASSISTANT, this)
3851
}
52+
53+
/**
54+
* Returns a new [ChatMessage] using [ChatUser.TOOL].
55+
*/
56+
@JvmStatic
57+
fun String.toToolMessage(toolCallId: String): ChatMessage {
58+
return ChatMessage(ChatUser.TOOL, this, toolCallId = toolCallId)
59+
}
3960
}
4061
}

0 commit comments

Comments
 (0)