@@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.*
4
4
import com.cjcrafter.openai.completions.CompletionRequest
5
5
import com.cjcrafter.openai.completions.CompletionResponse
6
6
import com.cjcrafter.openai.completions.CompletionResponseChunk
7
- import com.cjcrafter.openai.completions.CompletionUsage
7
+ import com.fasterxml.jackson.databind.JavaType
8
8
import com.fasterxml.jackson.databind.node.ObjectNode
9
9
import okhttp3.*
10
10
import okhttp3.MediaType.Companion.toMediaType
@@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
32
32
.post(body).build()
33
33
}
34
34
35
- override fun createCompletion (request : CompletionRequest ): CompletionResponse {
36
- @Suppress(" DEPRECATION" )
37
- request.stream = false // use streamCompletion for stream=true
38
- val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
39
-
40
- val httpResponse = client.newCall(httpRequest).execute()
41
- println (httpResponse)
42
-
43
- return CompletionResponse (" 1" , 1 , " 1" , listOf (), CompletionUsage (1 , 1 , 1 ))
44
- }
45
-
46
- override fun streamCompletion (request : CompletionRequest ): Iterable <CompletionResponseChunk > {
47
- @Suppress(" DEPRECATION" )
48
- request.stream = true // use createCompletion for stream=false
49
- val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
50
-
51
- return listOf ()
52
- }
53
-
54
- override fun createChatCompletion (request : ChatRequest ): ChatResponse {
55
- @Suppress(" DEPRECATION" )
56
- request.stream = false // use streamChatCompletion for stream=true
57
- val httpRequest = buildRequest(request, CHAT_ENDPOINT )
58
-
35
+ protected open fun <T > executeRequest (httpRequest : Request , responseType : Class <T >): T {
59
36
val httpResponse = client.newCall(httpRequest).execute()
60
37
if (! httpResponse.isSuccessful) {
61
38
val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
62
39
httpResponse.close()
63
- throw IOException (" Unexpected code $httpResponse , recieved : $json " )
40
+ throw IOException (" Unexpected code $httpResponse , received : $json " )
64
41
}
65
42
66
- val json = httpResponse.body?.byteStream()?.bufferedReader() ? : throw IOException (" Response body is null" )
67
- val str = json.readText()
68
- return objectMapper.readValue(str, ChatResponse ::class .java)
43
+ val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
44
+ ? : throw IOException (" Response body is null" )
45
+ val responseStr = jsonReader.readText()
46
+ return objectMapper.readValue(responseStr, responseType)
69
47
}
70
48
71
- override fun streamChatCompletion (request : ChatRequest ): Iterable <ChatResponseChunk > {
72
- request.stream = true // Set streaming to true
73
- val httpRequest = buildRequest(request, CHAT_ENDPOINT )
74
-
75
- return object : Iterable <ChatResponseChunk > {
76
- override fun iterator (): Iterator <ChatResponseChunk > {
77
- val httpResponse = client.newCall(httpRequest).execute()
49
+ private fun <T > streamResponses (
50
+ request : Request ,
51
+ responseType : JavaType ,
52
+ updateResponse : (T , String ) -> T
53
+ ): Iterable <T > {
54
+ return object : Iterable <T > {
55
+ override fun iterator (): Iterator <T > {
56
+ val httpResponse = client.newCall(request).execute()
78
57
79
58
if (! httpResponse.isSuccessful) {
80
59
httpResponse.close()
81
60
throw IOException (" Unexpected code $httpResponse " )
82
61
}
83
62
84
- val reader = httpResponse.body?.byteStream()?.bufferedReader() ? : throw IOException (" Response body is null" )
63
+ val reader = httpResponse.body?.byteStream()?.bufferedReader()
64
+ ? : throw IOException (" Response body is null" )
85
65
86
- // Only instantiate 1 ChatResponseChunk, otherwise simply update
87
- // the existing one. This lets us accumulate the message.
88
- var chunk: ChatResponseChunk ? = null
66
+ var currentResponse: T ? = null
89
67
90
- return object : Iterator <ChatResponseChunk > {
68
+ return object : Iterator <T > {
91
69
private var nextLine: String? = readNextLine(reader)
92
70
93
71
private fun readNextLine (reader : BufferedReader ): String? {
@@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
98
76
reader.close()
99
77
return null
100
78
}
101
-
102
- // Check if the line starts with 'data:' and skip empty lines
103
79
} while (line != null && (line.isEmpty() || ! line.startsWith(" data: " )))
104
80
return line?.removePrefix(" data: " )
105
81
}
@@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
108
84
return nextLine != null
109
85
}
110
86
111
- override fun next (): ChatResponseChunk {
112
- val currentLine = nextLine ? : throw NoSuchElementException (" No more lines" )
113
- // println(" $currentLine")
114
- chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode ) } ? : objectMapper.readValue(currentLine, ChatResponseChunk ::class .java)
115
- nextLine = readNextLine(reader) // Prepare the next line
116
- return chunk!!
117
- // return ChatResponseChunk("1", 1, listOf())
87
+ override fun next (): T {
88
+ val line = nextLine ? : throw NoSuchElementException (" No more lines" )
89
+ currentResponse = if (currentResponse == null ) {
90
+ objectMapper.readValue(line, responseType)
91
+ } else {
92
+ updateResponse(currentResponse!! , line)
93
+ }
94
+ nextLine = readNextLine(reader)
95
+ return currentResponse!!
118
96
}
119
97
}
120
98
}
121
99
}
122
100
}
123
101
102
+ override fun createCompletion (request : CompletionRequest ): CompletionResponse {
103
+ @Suppress(" DEPRECATION" )
104
+ request.stream = false // use streamCompletion for stream=true
105
+ val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
106
+ return executeRequest(httpRequest, CompletionResponse ::class .java)
107
+ }
108
+
109
+ override fun streamCompletion (request : CompletionRequest ): Iterable <CompletionResponseChunk > {
110
+ @Suppress(" DEPRECATION" )
111
+ request.stream = true
112
+ val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT )
113
+ return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk ::class .java)) { response, newLine ->
114
+ // We don't have any update logic, so we should ignore the old response and just return a new one
115
+ objectMapper.readValue(newLine, CompletionResponseChunk ::class .java)
116
+ }
117
+ }
118
+
119
+ override fun createChatCompletion (request : ChatRequest ): ChatResponse {
120
+ @Suppress(" DEPRECATION" )
121
+ request.stream = false // use streamChatCompletion for stream=true
122
+ val httpRequest = buildRequest(request, CHAT_ENDPOINT )
123
+ return executeRequest(httpRequest, ChatResponse ::class .java)
124
+ }
125
+
126
+ override fun streamChatCompletion (request : ChatRequest ): Iterable <ChatResponseChunk > {
127
+ @Suppress(" DEPRECATION" )
128
+ request.stream = true
129
+ val httpRequest = buildRequest(request, CHAT_ENDPOINT )
130
+ return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk ::class .java)) { response, newLine ->
131
+ response.update(objectMapper.readTree(newLine) as ObjectNode )
132
+ response
133
+ }
134
+ }
135
+
124
136
companion object {
125
137
const val COMPLETIONS_ENDPOINT = " v1/completions"
126
138
const val CHAT_ENDPOINT = " v1/chat/completions"
127
- const val IMAGE_CREATE_ENDPOINT = " v1/images/generations"
128
- const val IMAGE_EDIT_ENDPOINT = " v1/images/edits"
129
- const val IMAGE_VARIATION_ENDPOINT = " v1/images/variations"
130
139
}
131
140
}
0 commit comments