Skip to content

Commit 96aa365

Browse files
committed
add gson adapters for string -> json -> ChatResponse compatibility
1 parent 2c2f63d commit 96aa365

File tree

7 files changed

+176
-13
lines changed

7 files changed

+176
-13
lines changed

src/main/kotlin/com/cjcrafter/openai/FinishReason.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.cjcrafter.openai
22

3+
import com.cjcrafter.openai.gson.FinishReasonAdapter
4+
35
/**
46
* [FinishReason] wraps the possible reasons that a generation model may stop
57
* generating tokens. For most **PROPER** use cases (see [best practices](https://platform.openai.com/docs/guides/chat/introduction)),
@@ -27,5 +29,19 @@ enum class FinishReason {
2729
* This occurrence is rare, and usually only happens when you blatantly
2830
* misuse/violate OpenAI's terms.
2931
*/
30-
CONTENT_FILTER
32+
CONTENT_FILTER;
33+
34+
companion object {
35+
36+
/**
37+
* Returns the google gson adapter for serializing this enum to a json
38+
* file. Whe
39+
*
40+
* @return
41+
*/
42+
@JvmStatic
43+
fun adapter() : FinishReasonAdapter {
44+
return FinishReasonAdapter()
45+
}
46+
}
3147
}

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package com.cjcrafter.openai
22

3-
import com.cjcrafter.openai.chat.ChatRequest
4-
import com.cjcrafter.openai.chat.ChatResponse
5-
import com.cjcrafter.openai.chat.ChatResponseChunk
6-
import com.cjcrafter.openai.chat.ChatUser
3+
import ChatChoiceChunkAdapter
4+
import com.cjcrafter.openai.chat.*
75
import com.cjcrafter.openai.exception.OpenAIError
86
import com.cjcrafter.openai.exception.WrappedIOError
7+
import com.cjcrafter.openai.gson.ChatUserAdapter
8+
import com.cjcrafter.openai.gson.FinishReasonAdapter
99
import com.google.gson.Gson
1010
import com.google.gson.GsonBuilder
1111
import com.google.gson.JsonObject
@@ -39,10 +39,8 @@ class OpenAI @JvmOverloads constructor(
3939
private val client: OkHttpClient = OkHttpClient()
4040
) {
4141

42-
private val mediaType: MediaType = "application/json; charset=utf-8".toMediaType()
43-
private val gson: Gson = GsonBuilder()
44-
.registerTypeAdapter(ChatUser::class.java, JsonSerializer<ChatUser> { src, _, context -> context!!.serialize(src!!.name.lowercase())!! })
45-
.create()
42+
private val mediaType = "application/json; charset=utf-8".toMediaType()
43+
private val gson = createGson()
4644

4745
private fun buildRequest(request: Any): Request {
4846
val json = gson.toJson(request)
@@ -178,7 +176,7 @@ class OpenAI @JvmOverloads constructor(
178176

179177
val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
180178
if (cache == null)
181-
cache = ChatResponseChunk(rootObject)
179+
cache = gson.fromJson(rootObject, ChatResponseChunk::class.java)
182180
else
183181
cache!!.update(rootObject)
184182

@@ -188,4 +186,16 @@ class OpenAI @JvmOverloads constructor(
188186
}
189187
})
190188
}
189+
190+
companion object {
191+
192+
@JvmStatic
193+
fun createGson(): Gson {
194+
return GsonBuilder()
195+
.registerTypeAdapter(ChatUser::class.java, ChatUserAdapter())
196+
.registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter())
197+
.registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter())
198+
.create()
199+
}
200+
}
191201
}

src/main/kotlin/com/cjcrafter/openai/chat/ChatBot.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class ChatBot @JvmOverloads constructor(
7979
if (rootObject!!.has("error"))
8080
throw OpenAIError.fromJson(rootObject!!["error"].asJsonObject)
8181

82-
return ChatResponse(rootObject!!)
82+
return gson.fromJson(rootObject, ChatResponse::class.java)
8383
}
8484
} catch (ex: Throwable) {
8585
println(rootObject)
@@ -174,7 +174,7 @@ class ChatBot @JvmOverloads constructor(
174174

175175
val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
176176
if (cache == null)
177-
cache = ChatResponseChunk(rootObject)
177+
cache = gson.fromJson(rootObject, ChatResponseChunk::class.java)
178178
else
179179
cache!!.update(rootObject)
180180

src/main/kotlin/com/cjcrafter/openai/chat/ChatUser.kt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.cjcrafter.openai.chat
22

3+
import com.cjcrafter.openai.gson.ChatUserAdapter
4+
35
/**
46
* ChatGPT's biggest innovation is its conversational memory. To remember the
57
* conversation, we need to map each message to who sent it. This enum stores
@@ -31,5 +33,18 @@ enum class ChatUser {
3133
/**
3234
* [ASSISTANT] is the AI that generates responses.
3335
*/
34-
ASSISTANT
36+
ASSISTANT;
37+
38+
companion object {
39+
40+
/**
41+
* Adapter
42+
*
43+
* @return
44+
*/
45+
@JvmStatic
46+
fun adapter() : ChatUserAdapter {
47+
return ChatUserAdapter()
48+
}
49+
}
3550
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import com.cjcrafter.openai.FinishReason
2+
import com.cjcrafter.openai.chat.ChatChoiceChunk
3+
import com.cjcrafter.openai.chat.ChatMessage
4+
import com.cjcrafter.openai.chat.ChatUser
5+
import com.google.gson.GsonBuilder
6+
import com.google.gson.TypeAdapter
7+
import com.google.gson.stream.JsonReader
8+
import com.google.gson.stream.JsonToken
9+
import com.google.gson.stream.JsonWriter
10+
11+
class ChatChoiceChunkAdapter : TypeAdapter<ChatChoiceChunk?>() {
12+
13+
override fun write(writer: JsonWriter, value: ChatChoiceChunk?) {
14+
if (value == null) {
15+
writer.nullValue()
16+
} else {
17+
writer.beginObject()
18+
writer.name("index").value(value.index)
19+
writer.name("message").jsonValue(GsonBuilder().create().toJson(value.message))
20+
writer.name("delta").value(value.delta)
21+
writer.name("finish_reason")
22+
if (value.finishReason == null) {
23+
writer.nullValue()
24+
} else {
25+
writer.value(value.finishReason!!.name)
26+
}
27+
writer.endObject()
28+
}
29+
}
30+
31+
override fun read(reader: JsonReader): ChatChoiceChunk? {
32+
var index: Int = -1
33+
var message: ChatMessage? = null
34+
var delta: String? = null
35+
var finishReason: FinishReason? = null
36+
37+
reader.beginObject()
38+
while (reader.hasNext()) {
39+
when (reader.nextName()) {
40+
"index" -> index = reader.nextInt()
41+
"message" -> {
42+
when (reader.peek()) {
43+
JsonToken.NULL -> reader.nextNull()
44+
else -> message = GsonBuilder().create().fromJson(reader, ChatMessage::class.java)
45+
}
46+
}
47+
"delta" -> {
48+
when (reader.peek()) {
49+
JsonToken.BEGIN_OBJECT -> reader.skipValue()
50+
else -> delta = reader.nextString()
51+
}
52+
}
53+
"finish_reason" -> {
54+
when (reader.peek()) {
55+
JsonToken.NULL -> reader.skipValue()
56+
else -> finishReason = FinishReason.valueOf(reader.nextString().uppercase())
57+
}
58+
}
59+
else -> reader.skipValue()
60+
}
61+
}
62+
reader.endObject()
63+
64+
return ChatChoiceChunk(index, message ?: ChatMessage(ChatUser.ASSISTANT, ""), delta ?: "", finishReason)
65+
}
66+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.cjcrafter.openai.gson
2+
3+
import com.cjcrafter.openai.chat.ChatUser
4+
import com.google.gson.TypeAdapter
5+
import com.google.gson.stream.JsonReader
6+
import com.google.gson.stream.JsonToken
7+
import com.google.gson.stream.JsonWriter
8+
9+
class ChatUserAdapter : TypeAdapter<ChatUser?>() {
10+
11+
override fun write(writer: JsonWriter, value: ChatUser?) {
12+
if (value == null) {
13+
writer.nullValue()
14+
} else {
15+
writer.value(value.name.lowercase())
16+
}
17+
}
18+
19+
override fun read(reader: JsonReader): ChatUser? {
20+
return if (reader.peek() == JsonToken.NULL) {
21+
reader.nextNull()
22+
null
23+
} else {
24+
val name = reader.nextString()
25+
ChatUser.valueOf(name.uppercase())
26+
}
27+
}
28+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.cjcrafter.openai.gson
2+
3+
import com.cjcrafter.openai.FinishReason
4+
import com.google.gson.TypeAdapter
5+
import com.google.gson.stream.JsonReader
6+
import com.google.gson.stream.JsonToken
7+
import com.google.gson.stream.JsonWriter
8+
9+
class FinishReasonAdapter : TypeAdapter<FinishReason?>() {
10+
11+
override fun write(writer: JsonWriter, value: FinishReason?) {
12+
if (value == null) {
13+
writer.nullValue()
14+
} else {
15+
writer.value(value.name)
16+
}
17+
}
18+
19+
override fun read(reader: JsonReader): FinishReason? {
20+
return if (reader.peek() == JsonToken.NULL) {
21+
reader.nextNull()
22+
null
23+
} else {
24+
val name = reader.nextString()
25+
FinishReason.valueOf(name.uppercase())
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)