Commit 3c15343

New OpenAIService function/endpoint - createChatCompletion (impl, settings, response, json formatting)
1 parent c8e8a68 commit 3c15343

3 files changed: 75 additions (+), 1 deletion (-)

openai-client/src/main/scala/io/cequence/openaiscala/JsonFormats.scala

Lines changed: 26 additions & 0 deletions
@@ -1,5 +1,8 @@
 package io.cequence.openaiscala
 
+import io.cequence.openaiscala.JsonUtil.JsonOps
+import io.cequence.openaiscala.domain.ChatRole
+
 import java.{util => ju}
 import io.cequence.openaiscala.domain.response._
 import play.api.libs.functional.syntax._
@@ -20,6 +23,29 @@ object JsonFormats {
   implicit val textCompletionChoiceInfoFormat: Format[TextCompletionChoiceInfo] = Json.format[TextCompletionChoiceInfo]
   implicit val textCompletionFormat: Format[TextCompletionResponse] = Json.format[TextCompletionResponse]
 
+  implicit object ChatRoleFormat extends Format[ChatRole] {
+    override def reads(json: JsValue): JsResult[ChatRole] = {
+      json.asSafe[String] match {
+        case "user" => JsSuccess(ChatRole.User)
+        case "assistant" => JsSuccess(ChatRole.Assistant)
+        case "system" => JsSuccess(ChatRole.System)
+        case x => JsError(s"$x is not a valid message role.")
+      }
+    }
+
+    override def writes(o: ChatRole): JsValue = {
+      JsString(o.toString.toLowerCase())
+    }
+  }
+
+  implicit val chatMessageFormat: Format[ChatMessage] = Json.format[ChatMessage]
+  implicit val chatCompletionChoiceInfoFormat: Format[ChatCompletionChoiceInfo] = Json.format[ChatCompletionChoiceInfo]
+  implicit val chatCompletionResponseFormat: Format[ChatCompletionResponse] = Json.format[ChatCompletionResponse]
+
+  implicit val chatChunkMessageFormat: Format[ChatChunkMessage] = Json.format[ChatChunkMessage]
+  implicit val chatCompletionChoiceChunkInfoFormat: Format[ChatCompletionChoiceChunkInfo] = Json.format[ChatCompletionChoiceChunkInfo]
+  implicit val chatCompletionChunkResponseFormat: Format[ChatCompletionChunkResponse] = Json.format[ChatCompletionChunkResponse]
+
   implicit val textEditChoiceInfoFormat: Format[TextEditChoiceInfo] = Json.format[TextEditChoiceInfo]
   implicit val textEditFormat: Format[TextEditResponse] = Json.format[TextEditResponse]
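
Illustrative usage sketch (not part of the commit): with the implicits from JsonFormats in scope, the new ChatRoleFormat writes a role as the lowercase string the OpenAI API expects and reads it back, assuming the ChatRole case objects (User, Assistant, System) print their names via toString.

import io.cequence.openaiscala.JsonFormats._
import io.cequence.openaiscala.domain.ChatRole
import play.api.libs.json._

// writes: ChatRole.Assistant serializes to JsString("assistant")
val roleJson: JsValue = Json.toJson(ChatRole.Assistant: ChatRole)

// reads: "system" parses back to ChatRole.System; any other string
// (e.g. "moderator") falls through to the JsError case
val parsed: JsResult[ChatRole] = JsString("system").validate[ChatRole]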

openai-client/src/main/scala/io/cequence/openaiscala/service/Command.scala

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@ package io.cequence.openaiscala.service
 object Command extends Enumeration {
   val models = Value
   val completions = Value
+  val chat_completions = Value("chat/completions")
   val edits = Value
   val images_generations = Value("images/generations")
   val images_edits = Value("images/edits")
@@ -18,7 +19,7 @@ object Command extends Enumeration {
 
 object Tag extends Enumeration {
   val model, prompt, suffix, max_tokens, temperature, top_p, n, stream, logprobs, echo, stop,
-    presence_penalty, frequency_penalty, best_of, logit_bias, user,
+    presence_penalty, frequency_penalty, best_of, logit_bias, user, messages,
     input, image, mask, instruction, size, response_format, file, purpose, file_id,
     training_file, validation_file, n_epochs, batch_size, learning_rate_multiplier, prompt_loss_weight,
     compute_classification_metrics, classification_n_classes, classification_positive_class,
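
Side note, for illustration only: because the new enumeration value overrides its name, toString yields the REST path segment, which is presumably how WSRequestHelper assembles the endpoint URL (as with the existing images/generations value), and Tag.messages names the new request-body field.

import io.cequence.openaiscala.service.{Command, Tag}

// the overridden name doubles as the endpoint path segment
val endpointPath: String = Command.chat_completions.toString // "chat/completions"

// the new body-parameter tag
val messagesField: String = Tag.messages.toString // "messages"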

openai-client/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceImpl.scala

Lines changed: 47 additions & 0 deletions
@@ -9,6 +9,7 @@ import io.cequence.openaiscala.OpenAIScalaClientException
 import io.cequence.openaiscala.domain.settings._
 import io.cequence.openaiscala.domain.response._
 import io.cequence.openaiscala.ConfigImplicits._
+import io.cequence.openaiscala.domain.MessageSpec
 import io.cequence.openaiscala.service.ws.{Timeouts, WSRequestHelper}
 
 import java.io.File
@@ -118,6 +119,52 @@ private class OpenAIServiceImpl(
       Tag.user -> settings.user
     )
 
+  override def createChatCompletion(
+    messages: Seq[MessageSpec],
+    settings: CreateChatCompletionSettings
+  ): Future[ChatCompletionResponse] =
+    execPOST(
+      Command.chat_completions,
+      bodyParams = createBodyParamsForChatCompletion(messages, settings, stream = false)
+    ).map(
+      _.asSafe[ChatCompletionResponse]
+    )
+
+  protected def createBodyParamsForChatCompletion(
+    messages: Seq[MessageSpec],
+    settings: CreateChatCompletionSettings,
+    stream: Boolean
+  ) = {
+    assert(messages.nonEmpty, "At least one message expected.")
+
+    val messageJsons = messages.map { case MessageSpec(role, content) =>
+      Json.obj("role" -> role.toString.toLowerCase, "content" -> content)
+    }
+
+    jsonBodyParams(
+      Tag.messages -> Some(JsArray(messageJsons)),
+      Tag.model -> Some(settings.model),
+      Tag.temperature -> settings.temperature,
+      Tag.top_p -> settings.top_p,
+      Tag.n -> settings.n,
+      Tag.stream -> Some(stream),
+      Tag.stop -> {
+        settings.stop.size match {
+          case 0 => None
+          case 1 => Some(settings.stop.head)
+          case _ => Some(settings.stop)
+        }
+      },
+      Tag.max_tokens -> settings.max_tokens,
+      Tag.presence_penalty -> settings.presence_penalty,
+      Tag.frequency_penalty -> settings.frequency_penalty,
+      Tag.logit_bias -> {
+        if (settings.logit_bias.isEmpty) None else Some(settings.logit_bias)
+      },
+      Tag.user -> settings.user
+    )
+  }
+
   override def createEdit(
     input: String,
     instruction: String,
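
Rough end-to-end usage sketch (not part of the commit; the service wiring, the CreateChatCompletionSettings fields, and the ChatCompletionResponse field access below are assumptions based on the surrounding code and the OpenAI response shape):

import scala.concurrent.ExecutionContext.Implicits.global
import io.cequence.openaiscala.domain.{ChatRole, MessageSpec}
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.service.OpenAIServiceFactory // assumed; not introduced in this commit

// hypothetical wiring; how the service instance is obtained is outside this commit
val service = OpenAIServiceFactory()

val messages = Seq(
  MessageSpec(ChatRole.System, "You are a helpful assistant."),
  MessageSpec(ChatRole.User, "Summarize the createChatCompletion endpoint in one sentence.")
)

service
  .createChatCompletion(
    messages = messages,
    settings = CreateChatCompletionSettings(model = "gpt-3.5-turbo") // assumed constructor field
  )
  .map { response =>
    // assumed response structure: choices -> message -> content
    println(response.choices.head.message.content)
  }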
