Skip to content

Commit 4bcd10c

Browse files
committed
make CompletionRequest mutable
1 parent 1405a31 commit 4bcd10c

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

src/main/kotlin/com/cjcrafter/openai/OpenAI.kt

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package com.cjcrafter.openai
22

33
import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter
44
import com.cjcrafter.openai.chat.*
5+
import com.cjcrafter.openai.completions.CompletionRequest
6+
import com.cjcrafter.openai.completions.CompletionResponse
57
import com.cjcrafter.openai.exception.OpenAIError
68
import com.cjcrafter.openai.exception.WrappedIOError
79
import com.cjcrafter.openai.gson.ChatUserAdapter
@@ -40,17 +42,42 @@ class OpenAI @JvmOverloads constructor(
4042
private val mediaType = "application/json; charset=utf-8".toMediaType()
4143
private val gson = createGson()
4244

43-
private fun buildRequest(request: Any): Request {
45+
private fun buildRequest(request: Any, endpoint: String): Request {
4446
val json = gson.toJson(request)
4547
val body: RequestBody = json.toRequestBody(mediaType)
4648
return Request.Builder()
47-
.url("https://api.openai.com/v1/chat/completions")
49+
.url("https://api.openai.com/v1/$endpoint")
4850
.addHeader("Content-Type", "application/json")
4951
.addHeader("Authorization", "Bearer $apiKey")
5052
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
5153
.post(body).build()
5254
}
5355

56+
@Throws(OpenAIError::class)
57+
fun createCompletion(request: CompletionRequest): CompletionResponse {
58+
@Suppress("DEPRECATION")
59+
request.stream = false // use streamResponse for stream=true
60+
val httpRequest = buildRequest(request, "completions")
61+
62+
// Save the JsonObject to check for errors
63+
var rootObject: JsonObject?
64+
try {
65+
client.newCall(httpRequest).execute().use { response ->
66+
67+
// Servers respond to API calls with json blocks. Since raw JSON isn't
68+
// very developer friendly, we wrap for easy data access.
69+
rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
70+
if (rootObject!!.has("error"))
71+
throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject)
72+
73+
return gson.fromJson(rootObject, CompletionResponse::class.java)
74+
//return ChatResponse(rootObject!!)
75+
}
76+
} catch (ex: IOException) {
77+
throw WrappedIOError(ex)
78+
}
79+
}
80+
5481
/**
5582
* Blocks the current thread until OpenAI responds to https request. The
5683
* returned value includes information including tokens, generated text,
@@ -65,7 +92,7 @@ class OpenAI @JvmOverloads constructor(
6592
fun createChatCompletion(request: ChatRequest): ChatResponse {
6693
@Suppress("DEPRECATION")
6794
request.stream = false // use streamResponse for stream=true
68-
val httpRequest = buildRequest(request)
95+
val httpRequest = buildRequest(request, "chat/completions")
6996

7097
// Save the JsonObject to check for errors
7198
var rootObject: JsonObject?
@@ -150,7 +177,7 @@ class OpenAI @JvmOverloads constructor(
150177
) {
151178
@Suppress("DEPRECATION")
152179
request.stream = true // use requestResponse for stream=false
153-
val httpRequest = buildRequest(request)
180+
val httpRequest = buildRequest(request, "chat/completions")
154181

155182
client.newCall(httpRequest).enqueue(object : Callback {
156183
var cache: ChatResponseChunk? = null

src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ import com.google.gson.annotations.SerializedName
3434
* @constructor Create a CompletionRequest instance. Recommend using [builder] instead.
3535
*/
3636
data class CompletionRequest @JvmOverloads constructor(
37-
val model: String,
38-
val prompt: Any,
39-
val suffix: String? = null,
40-
@field:SerializedName("max_tokens") val maxTokens: Int? = null,
41-
val temperature: Number? = null,
42-
@field:SerializedName("top_p") val topP: Number? = null,
43-
val n: Int? = null,
44-
@Deprecated("Use OpenAI#streamCompletion") val stream: Boolean? = null,
45-
val logprobs: Int? = null,
46-
val echo: Boolean? = null,
47-
val stop: Any? = null,
48-
@field:SerializedName("presence_penalty") val presencePenalty: Number? = null,
49-
@field:SerializedName("frequency_penalty") val frequencyPenalty: Number? = null,
50-
@field:SerializedName("best_of") val bestOf: Int? = null,
51-
@field:SerializedName("logit_bias") val logitBias: Map<String, Int>? = null,
52-
val user: String? = null
37+
var model: String,
38+
var prompt: Any,
39+
var suffix: String? = null,
40+
@field:SerializedName("max_tokens") var maxTokens: Int? = null,
41+
var temperature: Number? = null,
42+
@field:SerializedName("top_p") var topP: Number? = null,
43+
var n: Int? = null,
44+
@Deprecated("Use OpenAI#streamCompletion") var stream: Boolean? = null,
45+
var logprobs: Int? = null,
46+
var echo: Boolean? = null,
47+
var stop: Any? = null,
48+
@field:SerializedName("presence_penalty") var presencePenalty: Number? = null,
49+
@field:SerializedName("frequency_penalty") var frequencyPenalty: Number? = null,
50+
@field:SerializedName("best_of") var bestOf: Int? = null,
51+
@field:SerializedName("logit_bias") var logitBias: Map<String, Int>? = null,
52+
var user: String? = null
5353
) {
5454

5555
/**

0 commit comments

Comments
 (0)