Skip to content

Commit b6b117a

Browse files
committed
reimplement completions api
1 parent de6be93 commit b6b117a

File tree

3 files changed

+92
-55
lines changed

3 files changed

+92
-55
lines changed

examples/src/main/kotlin/completion/Completion.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import com.cjcrafter.openai.openAI
55
import io.github.cdimascio.dotenv.dotenv
66

77
/**
8-
* In this Kotlin example, we will be using the Chat API to create a simple chatbot.
8+
* In this Kotlin example, we will be using the Completions API to generate a response.
99
*/
1010
fun main() {
1111

@@ -17,7 +17,7 @@ fun main() {
1717
// Here you can change the model's settings, add tools, and more.
1818
val request = completionRequest {
1919
model("davinci")
20-
prompt("What is 9+10")
20+
prompt("The wheels on the bus go")
2121
}
2222

2323
val completion = openai.createCompletion(request)[0]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package completion
2+
3+
import com.cjcrafter.openai.completions.completionRequest
4+
import com.cjcrafter.openai.openAI
5+
import io.github.cdimascio.dotenv.dotenv
6+
7+
/**
8+
* In this Kotlin example, we will be using the Completions API to generate a
9+
* response. We will stream the tokens 1 at a time for a faster response time.
10+
*/
11+
fun main() {
12+
13+
// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
14+
// dependency. Then you can add a .env file in your project directory.
15+
val key = dotenv()["OPENAI_TOKEN"]
16+
val openai = openAI { apiKey(key) }
17+
18+
// Here you can change the model's settings, add tools, and more.
19+
val request = completionRequest {
20+
model("davinci")
21+
prompt("The wheels on the bus go")
22+
maxTokens(500)
23+
}
24+
25+
for (chunk in openai.streamCompletion(request)) {
26+
print(chunk.choices[0].text)
27+
}
28+
}

src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.*
44
import com.cjcrafter.openai.completions.CompletionRequest
55
import com.cjcrafter.openai.completions.CompletionResponse
66
import com.cjcrafter.openai.completions.CompletionResponseChunk
7-
import com.cjcrafter.openai.completions.CompletionUsage
7+
import com.fasterxml.jackson.databind.JavaType
88
import com.fasterxml.jackson.databind.node.ObjectNode
99
import okhttp3.*
1010
import okhttp3.MediaType.Companion.toMediaType
@@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
3232
.post(body).build()
3333
}
3434

35-
override fun createCompletion(request: CompletionRequest): CompletionResponse {
36-
@Suppress("DEPRECATION")
37-
request.stream = false // use streamCompletion for stream=true
38-
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
39-
40-
val httpResponse = client.newCall(httpRequest).execute()
41-
println(httpResponse)
42-
43-
return CompletionResponse("1", 1, "1", listOf(), CompletionUsage(1, 1, 1))
44-
}
45-
46-
override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
47-
@Suppress("DEPRECATION")
48-
request.stream = true // use createCompletion for stream=false
49-
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
50-
51-
return listOf()
52-
}
53-
54-
override fun createChatCompletion(request: ChatRequest): ChatResponse {
55-
@Suppress("DEPRECATION")
56-
request.stream = false // use streamChatCompletion for stream=true
57-
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
58-
35+
protected open fun <T> executeRequest(httpRequest: Request, responseType: Class<T>): T {
5936
val httpResponse = client.newCall(httpRequest).execute()
6037
if (!httpResponse.isSuccessful) {
6138
val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
6239
httpResponse.close()
63-
throw IOException("Unexpected code $httpResponse, recieved: $json")
40+
throw IOException("Unexpected code $httpResponse, received: $json")
6441
}
6542

66-
val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
67-
val str = json.readText()
68-
return objectMapper.readValue(str, ChatResponse::class.java)
43+
val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
44+
?: throw IOException("Response body is null")
45+
val responseStr = jsonReader.readText()
46+
return objectMapper.readValue(responseStr, responseType)
6947
}
7048

71-
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
72-
request.stream = true // Set streaming to true
73-
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
74-
75-
return object : Iterable<ChatResponseChunk> {
76-
override fun iterator(): Iterator<ChatResponseChunk> {
77-
val httpResponse = client.newCall(httpRequest).execute()
49+
private fun <T> streamResponses(
50+
request: Request,
51+
responseType: JavaType,
52+
updateResponse: (T, String) -> T
53+
): Iterable<T> {
54+
return object : Iterable<T> {
55+
override fun iterator(): Iterator<T> {
56+
val httpResponse = client.newCall(request).execute()
7857

7958
if (!httpResponse.isSuccessful) {
8059
httpResponse.close()
8160
throw IOException("Unexpected code $httpResponse")
8261
}
8362

84-
val reader = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
63+
val reader = httpResponse.body?.byteStream()?.bufferedReader()
64+
?: throw IOException("Response body is null")
8565

86-
// Only instantiate 1 ChatResponseChunk, otherwise simply update
87-
// the existing one. This lets us accumulate the message.
88-
var chunk: ChatResponseChunk? = null
66+
var currentResponse: T? = null
8967

90-
return object : Iterator<ChatResponseChunk> {
68+
return object : Iterator<T> {
9169
private var nextLine: String? = readNextLine(reader)
9270

9371
private fun readNextLine(reader: BufferedReader): String? {
@@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
9876
reader.close()
9977
return null
10078
}
101-
102-
// Check if the line starts with 'data:' and skip empty lines
10379
} while (line != null && (line.isEmpty() || !line.startsWith("data: ")))
10480
return line?.removePrefix("data: ")
10581
}
@@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
10884
return nextLine != null
10985
}
11086

111-
override fun next(): ChatResponseChunk {
112-
val currentLine = nextLine ?: throw NoSuchElementException("No more lines")
113-
//println(" $currentLine")
114-
chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode) } ?: objectMapper.readValue(currentLine, ChatResponseChunk::class.java)
115-
nextLine = readNextLine(reader) // Prepare the next line
116-
return chunk!!
117-
//return ChatResponseChunk("1", 1, listOf())
87+
override fun next(): T {
88+
val line = nextLine ?: throw NoSuchElementException("No more lines")
89+
currentResponse = if (currentResponse == null) {
90+
objectMapper.readValue(line, responseType)
91+
} else {
92+
updateResponse(currentResponse!!, line)
93+
}
94+
nextLine = readNextLine(reader)
95+
return currentResponse!!
11896
}
11997
}
12098
}
12199
}
122100
}
123101

102+
override fun createCompletion(request: CompletionRequest): CompletionResponse {
103+
@Suppress("DEPRECATION")
104+
request.stream = false // use streamCompletion for stream=true
105+
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
106+
return executeRequest(httpRequest, CompletionResponse::class.java)
107+
}
108+
109+
override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
110+
@Suppress("DEPRECATION")
111+
request.stream = true
112+
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
113+
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk::class.java)) { response, newLine ->
114+
// We don't have any update logic, so we should ignore the old response and just return a new one
115+
objectMapper.readValue(newLine, CompletionResponseChunk::class.java)
116+
}
117+
}
118+
119+
override fun createChatCompletion(request: ChatRequest): ChatResponse {
120+
@Suppress("DEPRECATION")
121+
request.stream = false // use streamChatCompletion for stream=true
122+
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
123+
return executeRequest(httpRequest, ChatResponse::class.java)
124+
}
125+
126+
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
127+
@Suppress("DEPRECATION")
128+
request.stream = true
129+
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
130+
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk::class.java)) { response, newLine ->
131+
response.update(objectMapper.readTree(newLine) as ObjectNode)
132+
response
133+
}
134+
}
135+
124136
companion object {
125137
const val COMPLETIONS_ENDPOINT = "v1/completions"
126138
const val CHAT_ENDPOINT = "v1/chat/completions"
127-
const val IMAGE_CREATE_ENDPOINT = "v1/images/generations"
128-
const val IMAGE_EDIT_ENDPOINT = "v1/images/edits"
129-
const val IMAGE_VARIATION_ENDPOINT = "v1/images/variations"
130139
}
131140
}

0 commit comments

Comments
 (0)