Commit 91f1794

New method "createChatCompletionForFunctions" with an impl added to OpenAIService{Impl}, supporting the 'function call' response: the model may generate JSON inputs for the 'functions' whose definitions are passed as a param.
1 parent 3973d10 commit 91f1794

8 files changed: +84 -12 lines changed

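For orientation, here is a minimal usage sketch of the new method (not part of this commit; it assumes FunctionSpec exposes a name, an optional description, and a JSON-schema-like parameters map, and that MessageSpec(role, content) works with defaults for any remaining fields):

object WeatherFunctionCallExample {
  import io.cequence.openaiscala.domain.{ChatRole, FunctionSpec, MessageSpec}
  import io.cequence.openaiscala.domain.response.ChatCompletionResponse
  import io.cequence.openaiscala.service.OpenAIService
  import scala.concurrent.Future

  // Hypothetical function definition; the FunctionSpec field names are assumed, not shown on this page.
  def run(service: OpenAIService): Future[ChatCompletionResponse] = {
    val getWeather = FunctionSpec(
      name = "get_current_weather",
      description = Some("Get the current weather in a given location"),
      parameters = Map(
        "type" -> "object",
        "properties" -> Map(
          "location" -> Map("type" -> "string", "description" -> "City and state, e.g. San Francisco, CA")
        ),
        "required" -> Seq("location")
      )
    )

    service.createChatCompletionForFunctions(
      messages = Seq(MessageSpec(ChatRole.User, "What is the weather like in Boston?")),
      functions = Seq(getWeather),
      responseFunctionName = Some("get_current_weather") // omit to leave the choice to the model ("auto")
    )
  }
}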

openai-client-stream/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceStreamedExtra.scala

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ package io.cequence.openaiscala.service
 
 import akka.NotUsed
 import akka.stream.scaladsl.Source
-import io.cequence.openaiscala.domain.MessageSpec
+import io.cequence.openaiscala.domain.{FunctionSpec, MessageSpec}
 import io.cequence.openaiscala.domain.response.{ChatCompletionChunkResponse, ChatCompletionResponse, FineTuneEvent, TextCompletionResponse}
 import io.cequence.openaiscala.domain.settings.{CreateChatCompletionSettings, CreateCompletionSettings}
 
@@ -29,7 +29,7 @@ trait OpenAIServiceStreamedExtra extends OpenAIServiceConsts {
   /**
    * Creates a completion for the chat message(s) with streamed results.
    *
-   * @param messages The messages to generate chat completions.
+   * @param messages A list of messages comprising the conversation so far.
    * @param settings
    * @return chat completion response
    *

openai-client/src/main/scala/io/cequence/openaiscala/JsonFormats.scala

Lines changed: 6 additions & 1 deletion
@@ -1,7 +1,7 @@
 package io.cequence.openaiscala
 
 import io.cequence.openaiscala.JsonUtil.JsonOps
-import io.cequence.openaiscala.domain.ChatRole
+import io.cequence.openaiscala.domain.{ChatRole, FunctionSpec, MessageSpec, FunctionCallSpec}
 
 import java.{util => ju}
 import io.cequence.openaiscala.domain.response._
@@ -29,6 +29,7 @@ object JsonFormats {
       case "user" => JsSuccess(ChatRole.User)
       case "assistant" => JsSuccess(ChatRole.Assistant)
       case "system" => JsSuccess(ChatRole.System)
+      case "function" => JsSuccess(ChatRole.Function)
       case x => JsError(s"$x is not a valid message role.")
     }
   }
@@ -38,6 +39,10 @@ object JsonFormats {
     }
   }
 
+  implicit val messageSpecFormat: Format[MessageSpec] = Json.format[MessageSpec]
+  private implicit val stringAnyMapFormat: Format[Map[String, Any]] = JsonUtil.StringAnyMapFormat
+  implicit val FunctionCallSpec: Format[FunctionCallSpec] = Json.format[FunctionCallSpec]
+  implicit val functionSpecFormat: Format[FunctionSpec] = Json.format[FunctionSpec]
 implicit val chatMessageFormat: Format[ChatMessage] = Json.format[ChatMessage]
 implicit val chatCompletionChoiceInfoFormat: Format[ChatCompletionChoiceInfo] = Json.format[ChatCompletionChoiceInfo]
 implicit val chatCompletionResponseFormat: Format[ChatCompletionResponse] = Json.format[ChatCompletionResponse]
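With these formats in scope, a function definition becomes serializable with plain Json.toJson (a hedged sketch, not part of the commit; the FunctionSpec constructor fields are assumed):

object FunctionSpecJsonDemo extends App {
  import io.cequence.openaiscala.JsonFormats._
  import io.cequence.openaiscala.domain.FunctionSpec
  import play.api.libs.json.Json

  // the parameters map is serialized through the private stringAnyMapFormat (backed by JsonUtil.StringAnyMapFormat)
  val functionJson = Json.toJson(
    FunctionSpec(
      name = "get_current_weather",
      description = Some("Get the current weather in a given location"),
      parameters = Map("type" -> "object")
    )
  )

  println(Json.prettyPrint(functionJson))
}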

openai-client/src/main/scala/io/cequence/openaiscala/JsonUtil.scala

Lines changed: 12 additions & 0 deletions
@@ -95,4 +95,16 @@ object JsonUtil {
       JsObject(fields)
     }
   }
+
+  object StringAnyMapFormat extends Format[Map[String, Any]] {
+    override def reads(json: JsValue): JsResult[Map[String, Any]] = {
+      val resultJsons = json.asSafe[JsObject].fields.map { case (fieldName, jsValue) => (fieldName, jsValue.toString) }
+      JsSuccess(resultJsons.toMap)
+    }
+
+    override def writes(o: Map[String, Any]): JsValue = {
+      val fields = o.map { case (fieldName, value) => (fieldName, toJson(value)) }
+      JsObject(fields)
+    }
+  }
 }
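A small sketch of the new format in isolation (not part of the commit). Note that writes delegates to JsonUtil.toJson for each value, while reads stringifies values via jsValue.toString, so a round trip is lossy for anything but strings:

object StringAnyMapFormatDemo extends App {
  import io.cequence.openaiscala.JsonUtil
  import play.api.libs.json.Json

  // writes: Map[String, Any] -> JsObject (assuming JsonUtil.toJson handles primitive values)
  val written = JsonUtil.StringAnyMapFormat.writes(Map("type" -> "object", "minItems" -> 1))
  println(Json.stringify(written))

  // reads: values come back as raw JSON strings, e.g. "\"object\"" and "1"
  println(JsonUtil.StringAnyMapFormat.reads(written))
}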

openai-client/src/main/scala/io/cequence/openaiscala/service/Command.scala

Lines changed: 2 additions & 0 deletions
@@ -59,4 +59,6 @@ object Tag {
   case object classification_positive_class extends Tag
   case object classification_betas extends Tag
   case object language extends Tag
+  case object functions extends Tag
+  case object function_call extends Tag
 }

openai-client/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceImpl.scala

Lines changed: 26 additions & 4 deletions
@@ -9,7 +9,7 @@ import io.cequence.openaiscala.OpenAIScalaClientException
 import io.cequence.openaiscala.domain.settings._
 import io.cequence.openaiscala.domain.response._
 import io.cequence.openaiscala.ConfigImplicits._
-import io.cequence.openaiscala.domain.MessageSpec
+import io.cequence.openaiscala.domain.{FunctionSpec, MessageSpec}
 import io.cequence.openaiscala.service.ws.{Timeouts, WSRequestHelper}
 
 import java.io.File
@@ -116,16 +116,38 @@ private class OpenAIServiceImpl(
       _.asSafe[ChatCompletionResponse]
     )
 
+  override def createChatCompletionForFunctions(
+    messages: Seq[MessageSpec],
+    functions: Seq[FunctionSpec],
+    responseFunctionName: Option[String],
+    settings: CreateChatCompletionSettings
+  ): Future[ChatCompletionResponse] = {
+    val coreParams = createBodyParamsForChatCompletion(messages, settings, stream = false)
+
+    val extraParams = jsonBodyParams(
+      Tag.functions -> Some(Json.toJson(functions)),
+      Tag.function_call -> responseFunctionName.map(name => Map("name" -> name)), // otherwise "auto" is used by default
+    )
+
+    execPOST(
+      Command.chat_completions,
+      bodyParams = coreParams ++ extraParams
+    ).map(
+      _.asSafe[ChatCompletionResponse]
+    )
+  }
+
   protected def createBodyParamsForChatCompletion(
     messages: Seq[MessageSpec],
     settings: CreateChatCompletionSettings,
     stream: Boolean
   ) = {
     assert(messages.nonEmpty, "At least one message expected.")
 
-    val messageJsons = messages.map { case MessageSpec(role, content) =>
-      Json.obj("role" -> role.toString.toLowerCase, "content" -> content)
-    }
+    val messageJsons = messages.map(Json.toJson(_)(messageSpecFormat))
+
+    // case MessageSpec(role, content, name, function_call) =>
+    //   Json.obj("role" -> role.toString.toLowerCase, "content" -> content)
 
     jsonBodyParams(
       Tag.messages -> Some(JsArray(messageJsons)),
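For reference, the extra params built above would serialize to roughly the following fragment of the request body (a hedged illustration, not part of the commit; the exact JSON of each functions entry depends on functionSpecFormat and on the parameters map supplied by the caller):

object FunctionCallBodyIllustration {
  import play.api.libs.json.Json

  // "function_call" is only sent when responseFunctionName is defined; otherwise the API default ("auto") applies
  val extraBody = Json.obj(
    "functions" -> Json.arr(
      Json.obj(
        "name" -> "get_current_weather",
        "description" -> "Get the current weather in a given location",
        "parameters" -> Json.obj(
          "type" -> "object",
          "properties" -> Json.obj("location" -> Json.obj("type" -> "string")),
          "required" -> Json.arr("location")
        )
      )
    ),
    "function_call" -> Json.obj("name" -> "get_current_weather")
  )
}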

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIService.scala

Lines changed: 21 additions & 4 deletions
@@ -1,6 +1,6 @@
 package io.cequence.openaiscala.service
 
-import io.cequence.openaiscala.domain.MessageSpec
+import io.cequence.openaiscala.domain.{FunctionSpec, MessageSpec}
 import io.cequence.openaiscala.domain.settings._
 import io.cequence.openaiscala.domain.response._
 
@@ -61,19 +61,36 @@ trait OpenAIService extends OpenAIServiceConsts {
   ): Future[TextCompletionResponse]
 
   /**
-   * Creates a completion for the chat message(s).
+   * Creates a model response for the given chat conversation.
    *
-   * @param messages The messages to generate chat completions.
+   * @param messages A list of messages comprising the conversation so far.
    * @param settings
   * @return chat completion response
-   *
   * @see <a href="https://platform.openai.com/docs/api-reference/chat/create">OpenAI Doc</a>
   */
  def createChatCompletion(
    messages: Seq[MessageSpec],
    settings: CreateChatCompletionSettings = DefaultSettings.CreateChatCompletion
  ): Future[ChatCompletionResponse]
 
+  /**
+   * Creates a model response for the given chat conversation expecting a function call.
+   *
+   * @param messages A list of messages comprising the conversation so far.
+   * @param functions A list of functions the model may generate JSON inputs for.
+   * @param responseFunctionName If specified it forces the model to respond with a call to that function (must be listed in `functions`).
+   *                             Otherwise, the default "auto" mode is used where the model can pick between an end-user or calling a function.
+   * @param settings
+   * @return chat completion response
+   * @see <a href="https://platform.openai.com/docs/api-reference/chat/create">OpenAI Doc</a>
+   */
+  def createChatCompletionForFunctions(
+    messages: Seq[MessageSpec],
+    functions: Seq[FunctionSpec],
+    responseFunctionName: Option[String] = None,
+    settings: CreateChatCompletionSettings = DefaultSettings.CreateChatCompletionForFunctions
+  ): Future[ChatCompletionResponse]
+
   /**
    * Creates a new edit for the provided input, instruction, and parameters.
    *
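On the consumer side, the returned ChatCompletionResponse can then be inspected for a function call (a hedged sketch, not part of the commit; the message fields used here, especially function_call with its name and arguments, are assumed rather than shown in this diff):

object FunctionCallResponseHandling {
  import io.cequence.openaiscala.domain.response.ChatCompletionResponse

  def handle(response: ChatCompletionResponse): Unit =
    response.choices.headOption.foreach { choice =>
      choice.message.function_call match {
        case Some(call) =>
          // `arguments` would hold the model-generated JSON string matching the declared parameter schema
          println(s"Call ${call.name} with arguments: ${call.arguments}")
        case None =>
          println(s"Plain assistant reply: ${choice.message.content}")
      }
    }
}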

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceConsts.scala

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,11 @@ trait OpenAIServiceConsts {
       max_tokens = Some(1000)
     )
 
+    val CreateChatCompletionForFunctions = CreateChatCompletionSettings(
+      model = ModelId.gpt_3_5_turbo_0613,
+      max_tokens = Some(1000)
+    )
+
     val CreateEdit = CreateEditSettings(
       model = ModelId.text_davinci_edit_001,
       temperature = Some(0.7)
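Callers that need different generation behaviour can pass their own settings instead of this new default (a hedged sketch, not part of the commit; it assumes temperature and max_tokens are optional fields of CreateChatCompletionSettings):

object StrictFunctionSettings {
  import io.cequence.openaiscala.domain.ModelId
  import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings

  // a deterministic variant of the default, with a smaller token budget
  val settings = CreateChatCompletionSettings(
    model = ModelId.gpt_3_5_turbo_0613,
    temperature = Some(0.0),
    max_tokens = Some(500)
  )
}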

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceWrapper.scala

Lines changed: 10 additions & 1 deletion
@@ -1,6 +1,6 @@
 package io.cequence.openaiscala.service
 
-import io.cequence.openaiscala.domain.MessageSpec
+import io.cequence.openaiscala.domain.{FunctionSpec, MessageSpec}
 import io.cequence.openaiscala.domain.settings._
 
 import java.io.File
@@ -33,6 +33,15 @@ trait OpenAIServiceWrapper extends OpenAIService {
     _.createChatCompletion(messages, settings)
   )
 
+  override def createChatCompletionForFunctions(
+    messages: Seq[MessageSpec],
+    functions: Seq[FunctionSpec],
+    responseFunctionName: Option[String],
+    settings: CreateChatCompletionSettings
+  ) = wrap(
+    _.createChatCompletionForFunctions(messages, functions, responseFunctionName, settings)
+  )
+
   override def createEdit(
     input: String,
     instruction: String,
