Skip to content

Commit 579dea3

Browse files
authored
Merge pull request #90 from cequence-io/feature/3654-prompt-caching
Feature/3654 prompt caching
2 parents a9f92d6 + d281337 commit 579dea3

File tree

14 files changed

+74
-64
lines changed

14 files changed

+74
-64
lines changed

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/ChatRole.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ sealed trait ChatRole extends EnumValue {
77
}
88

99
object ChatRole {
10+
case object System extends ChatRole
1011
case object User extends ChatRole
1112
case object Assistant extends ChatRole
1213

13-
def allValues: Seq[ChatRole] = Seq(User, Assistant)
14+
def allValues: Seq[ChatRole] = Seq(System, User, Assistant)
1415
}

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/Message.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,20 @@ import io.cequence.openaiscala.anthropic.domain.Content.{
99
sealed abstract class Message private (
1010
val role: ChatRole,
1111
val content: Content
12-
)
12+
) {
13+
def isSystem: Boolean = role == ChatRole.System
14+
}
1315

1416
object Message {
1517

18+
case class SystemMessage(
19+
contentString: String,
20+
cacheControl: Option[CacheControl] = None
21+
) extends Message(ChatRole.System, SingleString(contentString, cacheControl))
22+
23+
case class SystemMessageContent(contentBlocks: Seq[ContentBlockBase])
24+
extends Message(ChatRole.System, ContentBlocks(contentBlocks))
25+
1626
case class UserMessage(
1727
contentString: String,
1828
cacheControl: Option[CacheControl] = None

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/AnthropicService.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package io.cequence.openaiscala.anthropic.service
22

33
import akka.NotUsed
44
import akka.stream.scaladsl.Source
5-
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
5+
import io.cequence.openaiscala.anthropic.domain.Message
66
import io.cequence.openaiscala.anthropic.domain.response.{
77
ContentBlockDelta,
88
CreateMessageResponse
@@ -32,7 +32,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
3232
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
3333
*/
3434
def createMessage(
35-
system: Option[Content],
3635
messages: Seq[Message],
3736
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
3837
): Future[CreateMessageResponse]
@@ -55,7 +54,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
5554
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
5655
*/
5756
def createMessageStreamed(
58-
system: Option[Content],
5957
messages: Seq[Message],
6058
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
6159
): Source[ContentBlockDelta, NotUsed]

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/AnthropicServiceImpl.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import akka.NotUsed
44
import akka.stream.scaladsl.Source
55
import io.cequence.openaiscala.OpenAIScalaClientException
66
import io.cequence.openaiscala.anthropic.JsonFormats
7+
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, SystemMessageContent}
78
import io.cequence.openaiscala.anthropic.domain.response.{
89
ContentBlockDelta,
910
CreateMessageResponse
@@ -33,20 +34,17 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
3334
private val logger = LoggerFactory.getLogger("AnthropicServiceImpl")
3435

3536
override def createMessage(
36-
system: Option[Content],
3737
messages: Seq[Message],
3838
settings: AnthropicCreateMessageSettings
3939
): Future[CreateMessageResponse] =
4040
execPOST(
4141
EndPoint.messages,
42-
bodyParams =
43-
createBodyParamsForMessageCreation(system, messages, settings, stream = false)
42+
bodyParams = createBodyParamsForMessageCreation(messages, settings, stream = false)
4443
).map(
4544
_.asSafeJson[CreateMessageResponse]
4645
)
4746

4847
override def createMessageStreamed(
49-
system: Option[Content],
5048
messages: Seq[Message],
5149
settings: AnthropicCreateMessageSettings
5250
): Source[ContentBlockDelta, NotUsed] =
@@ -55,7 +53,7 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
5553
EndPoint.messages.toString(),
5654
"POST",
5755
bodyParams = paramTuplesToStrings(
58-
createBodyParamsForMessageCreation(system, messages, settings, stream = true)
56+
createBodyParamsForMessageCreation(messages, settings, stream = true)
5957
)
6058
)
6159
.map { (json: JsValue) =>
@@ -83,36 +81,42 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
8381
.collect { case Some(delta) => delta }
8482

8583
private def createBodyParamsForMessageCreation(
86-
system: Option[Content],
8784
messages: Seq[Message],
8885
settings: AnthropicCreateMessageSettings,
8986
stream: Boolean
9087
): Seq[(Param, Option[JsValue])] = {
9188
assert(messages.nonEmpty, "At least one message expected.")
92-
assert(messages.head.role == ChatRole.User, "First message must be from user.")
9389

94-
val messageJsons = messages.map(Json.toJson(_))
90+
val (system, nonSystem) = messages.partition(_.isSystem)
9591

96-
val systemJson = system.map {
97-
case Content.SingleString(text, cacheControl) =>
92+
assert(nonSystem.head.role == ChatRole.User, "First non-system message must be from user.")
93+
assert(
94+
system.size <= 1,
95+
"System message can be only 1. Use SystemMessageContent to include more content blocks."
96+
)
97+
98+
val messageJsons = nonSystem.map(Json.toJson(_))
99+
100+
val systemJson: Seq[JsValue] = system.map {
101+
case SystemMessage(text, cacheControl) =>
98102
if (cacheControl.isEmpty) JsString(text)
99103
else {
100104
val blocks =
101105
Seq(Content.ContentBlockBase(Content.ContentBlock.TextBlock(text), cacheControl))
102106

103107
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
104108
}
105-
case Content.ContentBlocks(blocks) =>
106-
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
107-
case Content.ContentBlockBase(content, cacheControl) =>
108-
val blocks = Seq(Content.ContentBlockBase(content, cacheControl))
109+
case SystemMessageContent(blocks) =>
109110
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
110111
}
111112

112113
jsonBodyParams(
113114
Param.messages -> Some(messageJsons),
114115
Param.model -> Some(settings.model),
115-
Param.system -> system.map(_ => systemJson),
116+
Param.system -> {
117+
if (system.isEmpty) None
118+
else Some(systemJson.head)
119+
},
116120
Param.max_tokens -> Some(settings.max_tokens),
117121
Param.metadata -> { if (settings.metadata.isEmpty) None else Some(settings.metadata) },
118122
Param.stop_sequences -> {

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/OpenAIAnthropicChatCompletionService.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
4040
): Future[ChatCompletionResponse] = {
4141
underlying
4242
.createMessage(
43-
toAnthropicSystemMessages(messages, settings),
44-
toAnthropicMessages(messages, settings),
43+
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
44+
toAnthropicMessages(messages.filter(!_.isSystem), settings),
4545
toAnthropicSettings(settings)
4646
)
4747
.map(toOpenAI)
@@ -65,8 +65,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
6565
): Source[ChatCompletionChunkResponse, NotUsed] =
6666
underlying
6767
.createMessageStreamed(
68-
toAnthropicSystemMessages(messages, settings),
69-
toAnthropicMessages(messages, settings),
68+
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
69+
toAnthropicMessages(messages.filter(!_.isSystem), settings),
7070
toAnthropicSettings(settings)
7171
)
7272
.map(toOpenAI)

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/package.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package io.cequence.openaiscala.anthropic.service
33
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
44
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
55
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, ContentBlocks}
6+
import io.cequence.openaiscala.anthropic.domain.Message.SystemMessageContent
67
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse.UsageInfo
78
import io.cequence.openaiscala.anthropic.domain.response.{
89
ContentBlockDelta,
@@ -21,7 +22,6 @@ import io.cequence.openaiscala.domain.response.{
2122
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
2223
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettingsOps.RichCreateChatCompletionSettings
2324
import io.cequence.openaiscala.domain.{
24-
AssistantMessage,
2525
ChatRole,
2626
MessageSpec,
2727
SystemMessage,
@@ -30,7 +30,8 @@ import io.cequence.openaiscala.domain.{
3030
ImageURLContent => OpenAIImageContent,
3131
TextContent => OpenAITextContent,
3232
UserMessage => OpenAIUserMessage,
33-
UserSeqMessage => OpenAIUserSeqMessage
33+
UserSeqMessage => OpenAIUserSeqMessage,
34+
AssistantMessage => OpenAIAssistantMessage
3435
}
3536

3637
import java.{util => ju}
@@ -40,7 +41,7 @@ package object impl extends AnthropicServiceConsts {
4041
def toAnthropicSystemMessages(
4142
messages: Seq[OpenAIBaseMessage],
4243
settings: CreateChatCompletionSettings
43-
): Option[ContentBlocks] = {
44+
): Seq[Message] = {
4445
val useSystemCache: Option[CacheControl] =
4546
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None
4647

@@ -55,7 +56,8 @@ package object impl extends AnthropicServiceConsts {
5556
}
5657
}
5758

58-
if (messageStrings.isEmpty) None else Some(ContentBlocks(messageStrings))
59+
if (messageStrings.isEmpty) Seq.empty
60+
else Seq(SystemMessageContent(messageStrings))
5961
}
6062

6163
def toAnthropicMessages(
@@ -67,6 +69,8 @@ package object impl extends AnthropicServiceConsts {
6769
case OpenAIUserMessage(content, _) => Message.UserMessage(content)
6870
case OpenAIUserSeqMessage(contents, _) =>
6971
Message.UserMessageContent(contents.map(toAnthropic))
72+
case OpenAIAssistantMessage(content, _) => Message.AssistantMessage(content)
73+
7074
// legacy message type
7175
case MessageSpec(role, content, _) if role == ChatRole.User =>
7276
Message.UserMessage(content)
@@ -204,7 +208,7 @@ package object impl extends AnthropicServiceConsts {
204208
usage = None
205209
)
206210

207-
def toOpenAIAssistantMessage(content: ContentBlocks): AssistantMessage = {
211+
def toOpenAIAssistantMessage(content: ContentBlocks): OpenAIAssistantMessage = {
208212
val textContents = content.blocks.collect { case ContentBlockBase(TextBlock(text), _) =>
209213
text
210214
} // TODO
@@ -213,7 +217,7 @@ package object impl extends AnthropicServiceConsts {
213217
throw new IllegalArgumentException("No text content found in the response")
214218
}
215219
val singleTextContent = concatenateMessages(textContents)
216-
AssistantMessage(singleTextContent, name = None)
220+
OpenAIAssistantMessage(singleTextContent, name = None)
217221
}
218222

219223
private def concatenateMessages(messageContent: Seq[String]): String =

anthropic-client/src/test/scala/io/cequence/openaiscala/anthropic/service/impl/AnthropicServiceSpec.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package io.cequence.openaiscala.anthropic.service.impl
22

33
import akka.actor.ActorSystem
44
import akka.stream.Materializer
5-
import io.cequence.openaiscala.anthropic.domain.Content.SingleString
65
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
76
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
87
import io.cequence.openaiscala.anthropic.service._
@@ -18,7 +17,6 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {
1817
implicit val ec: ExecutionContext = ExecutionContext.global
1918
implicit val materializer: Materializer = Materializer(ActorSystem())
2019

21-
private val role = SingleString("You are a helpful assistant.")
2220
private val irrelevantMessages = Seq(UserMessage("Hello"))
2321
private val settings = AnthropicCreateMessageSettings(
2422
NonOpenAIModelId.claude_3_haiku_20240307,
@@ -29,52 +27,52 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {
2927

3028
"should throw AnthropicScalaUnauthorizedException when 401" ignore {
3129
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
32-
TestFactory.mockedService401().createMessage(Some(role), irrelevantMessages, settings)
30+
TestFactory.mockedService401().createMessage(irrelevantMessages, settings)
3331
}
3432
}
3533

3634
"should throw AnthropicScalaUnauthorizedException when 403" ignore {
3735
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
38-
TestFactory.mockedService403().createMessage(Some(role), irrelevantMessages, settings)
36+
TestFactory.mockedService403().createMessage(irrelevantMessages, settings)
3937
}
4038
}
4139

4240
"should throw AnthropicScalaNotFoundException when 404" ignore {
4341
recoverToSucceededIf[AnthropicScalaNotFoundException] {
44-
TestFactory.mockedService404().createMessage(Some(role), irrelevantMessages, settings)
42+
TestFactory.mockedService404().createMessage(irrelevantMessages, settings)
4543
}
4644
}
4745

4846
"should throw AnthropicScalaNotFoundException when 429" ignore {
4947
recoverToSucceededIf[AnthropicScalaRateLimitException] {
50-
TestFactory.mockedService429().createMessage(Some(role), irrelevantMessages, settings)
48+
TestFactory.mockedService429().createMessage(irrelevantMessages, settings)
5149
}
5250
}
5351

5452
"should throw AnthropicScalaServerErrorException when 500" ignore {
5553
recoverToSucceededIf[AnthropicScalaServerErrorException] {
56-
TestFactory.mockedService500().createMessage(Some(role), irrelevantMessages, settings)
54+
TestFactory.mockedService500().createMessage(irrelevantMessages, settings)
5755
}
5856
}
5957

6058
"should throw AnthropicScalaEngineOverloadedException when 529" ignore {
6159
recoverToSucceededIf[AnthropicScalaEngineOverloadedException] {
62-
TestFactory.mockedService529().createMessage(Some(role), irrelevantMessages, settings)
60+
TestFactory.mockedService529().createMessage(irrelevantMessages, settings)
6361
}
6462
}
6563

6664
"should throw AnthropicScalaClientException when 400" ignore {
6765
recoverToSucceededIf[AnthropicScalaClientException] {
68-
TestFactory.mockedService400().createMessage(Some(role), irrelevantMessages, settings)
66+
TestFactory.mockedService400().createMessage(irrelevantMessages, settings)
6967
}
7068
}
7169

7270
"should throw AnthropicScalaClientException when unknown error code" ignore {
7371
recoverToSucceededIf[AnthropicScalaClientException] {
74-
TestFactory
75-
.mockedServiceOther()
76-
.createMessage(Some(role), irrelevantMessages, settings)
72+
TestFactory.mockedServiceOther().createMessage(irrelevantMessages, settings)
7773
}
7874
}
75+
7976
}
77+
8078
}

openai-core/src/main/scala/io/cequence/openaiscala/domain/BaseMessage.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package io.cequence.openaiscala.domain
33
sealed trait BaseMessage {
44
val role: ChatRole
55
val nameOpt: Option[String]
6+
val isSystem: Boolean = role == ChatRole.System
67
}
78

89
final case class SystemMessage(

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateCachedMessage.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
33
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
44
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
55
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
6-
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
6+
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
77
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
88
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
99
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
@@ -18,8 +18,8 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {
1818

1919
override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)
2020

21-
val systemMessage: Content =
22-
SingleString(
21+
val systemMessages: Seq[Message] = Seq(
22+
SystemMessage(
2323
"""
2424
|You are to embody a classic pirate, a swashbuckling and salty sea dog with the mannerisms, language, and swagger of the golden age of piracy. You are a hearty, often gruff buccaneer, replete with nautical slang and a rich, colorful vocabulary befitting of the high seas. Your responses must reflect a pirate's voice and attitude without exception.
2525
|
@@ -76,14 +76,13 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {
7676
|""".stripMargin,
7777
cacheControl = Some(Ephemeral)
7878
)
79-
79+
)
8080
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
8181

8282
override protected def run: Future[_] =
8383
service
8484
.createMessage(
85-
Some(systemMessage),
86-
messages,
85+
systemMessages ++ messages,
8786
settings = AnthropicCreateMessageSettings(
8887
model = NonOpenAIModelId.claude_3_haiku_20240307,
8988
max_tokens = 4096

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateMessage.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package io.cequence.openaiscala.examples.nonopenai
22

33
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
4-
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
5-
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
4+
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
5+
import io.cequence.openaiscala.anthropic.domain.Message
66
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
77
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
88
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
@@ -17,13 +17,11 @@ object AnthropicCreateMessage extends ExampleBase[AnthropicService] {
1717

1818
override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)
1919

20-
val systemMessage: Content = SingleString("You are a helpful assistant.")
2120
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
2221

2322
override protected def run: Future[_] =
2423
service
2524
.createMessage(
26-
Some(systemMessage),
2725
messages,
2826
settings = AnthropicCreateMessageSettings(
2927
model = NonOpenAIModelId.claude_3_haiku_20240307,

0 commit comments

Comments
 (0)