Skip to content

Commit bf9f759

Browse files
committed
Bedrock/AWS impl with anthropic models support
1 parent b6a3faa commit bf9f759

File tree

8 files changed

+490
-102
lines changed

8 files changed

+490
-102
lines changed

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/AnthropicServiceConsts.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ trait AnthropicServiceConsts {
1010

1111
protected val defaultCoreUrl = "https://api.anthropic.com/v1/"
1212

13+
protected def bedrockCoreUrl(region: String) =
14+
s"https://bedrock-runtime.$region.amazonaws.com/"
15+
1316
object DefaultSettings {
1417

1518
val CreateMessage = AnthropicCreateMessageSettings(

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/AnthropicServiceFactory.scala

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package io.cequence.openaiscala.anthropic.service
22

33
import akka.stream.Materializer
44
import io.cequence.openaiscala.anthropic.service.impl.{
5+
AnthropicBedrockServiceImpl,
56
AnthropicServiceImpl,
7+
BedrockConnectionSettings,
68
OpenAIAnthropicChatCompletionService
79
}
810
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService
@@ -22,7 +24,13 @@ import scala.concurrent.ExecutionContext
2224
object AnthropicServiceFactory extends AnthropicServiceConsts {
2325

2426
private def apiVersion = "2023-06-01"
25-
private def envAPIKey = "ANTHROPIC_API_KEY"
27+
28+
object EnvKeys {
29+
val anthropicAPIKey = "ANTHROPIC_API_KEY"
30+
val bedrockAccessKey = "AWS_BEDROCK_ACCESS_KEY"
31+
val bedrockSecretKey = "AWS_BEDROCK_SECRET_KEY"
32+
val bedrockRegion = "AWS_BEDROCK_REGION"
33+
}
2634

2735
/**
2836
* Create a new instance of the [[OpenAIChatCompletionService]] wrapping the AnthropicService
@@ -37,7 +45,7 @@ object AnthropicServiceFactory extends AnthropicServiceConsts {
3745
* @return
3846
*/
3947
def asOpenAI(
40-
apiKey: String = getAPIKeyFromEnv(),
48+
apiKey: String = getEnvValue(EnvKeys.anthropicAPIKey),
4149
timeouts: Option[Timeouts] = None,
4250
withCache: Boolean = false
4351
)(
@@ -48,6 +56,19 @@ object AnthropicServiceFactory extends AnthropicServiceConsts {
4856
AnthropicServiceFactory(apiKey, timeouts, withPdf = false, withCache)
4957
)
5058

59+
def bedrockAsOpenAI(
60+
accessKey: String = getEnvValue(EnvKeys.bedrockAccessKey),
61+
secretKey: String = getEnvValue(EnvKeys.bedrockSecretKey),
62+
region: String = getEnvValue(EnvKeys.bedrockRegion),
63+
timeouts: Option[Timeouts] = None
64+
)(
65+
implicit ec: ExecutionContext,
66+
materializer: Materializer
67+
): OpenAIChatCompletionStreamedService =
68+
new OpenAIAnthropicChatCompletionService(
69+
AnthropicServiceFactory.forBedrock(accessKey, secretKey, region, timeouts)
70+
)
71+
5172
/**
5273
* Create a new instance of the [[AnthropicService]]
5374
*
@@ -61,7 +82,7 @@ object AnthropicServiceFactory extends AnthropicServiceConsts {
6182
* @return
6283
*/
6384
def apply(
64-
apiKey: String = getAPIKeyFromEnv(),
85+
apiKey: String = getEnvValue(EnvKeys.anthropicAPIKey),
6586
timeouts: Option[Timeouts] = None,
6687
withPdf: Boolean = false,
6788
withCache: Boolean = false
@@ -78,17 +99,31 @@ object AnthropicServiceFactory extends AnthropicServiceConsts {
7899
new AnthropicServiceClassImpl(defaultCoreUrl, authHeaders, timeouts)
79100
}
80101

81-
private def getAPIKeyFromEnv(): String =
82-
Option(System.getenv(envAPIKey)).getOrElse(
102+
def forBedrock(
103+
accessKey: String = getEnvValue(EnvKeys.bedrockAccessKey),
104+
secretKey: String = getEnvValue(EnvKeys.bedrockSecretKey),
105+
region: String = getEnvValue(EnvKeys.bedrockRegion),
106+
timeouts: Option[Timeouts] = None
107+
)(
108+
implicit ec: ExecutionContext,
109+
materializer: Materializer
110+
): AnthropicService =
111+
new AnthropicBedrockServiceClassImpl(
112+
BedrockConnectionSettings(accessKey, secretKey, region),
113+
timeouts
114+
)
115+
116+
private def getEnvValue(envKey: String): String =
117+
Option(System.getenv(envKey)).getOrElse(
83118
throw new IllegalStateException(
84-
"ANTHROPIC_API_KEY environment variable expected but not set. Alternatively, you can pass the API key explicitly to the factory method."
119+
s"${envKey} environment variable expected but not set. Alternatively, you can pass the API key explicitly to the factory method."
85120
)
86121
)
87122

88123
private class AnthropicServiceClassImpl(
89-
val coreUrl: String,
90-
val authHeaders: Seq[(String, String)],
91-
val explTimeouts: Option[Timeouts] = None
124+
coreUrl: String,
125+
authHeaders: Seq[(String, String)],
126+
explTimeouts: Option[Timeouts] = None
92127
)(
93128
implicit val ec: ExecutionContext,
94129
val materializer: Materializer
@@ -97,7 +132,25 @@ object AnthropicServiceFactory extends AnthropicServiceConsts {
97132
override protected val engine: WSClientEngine with WSClientEngineStreamExtra =
98133
PlayWSStreamClientEngine(
99134
coreUrl,
100-
WsRequestContext(authHeaders = authHeaders, explTimeouts = explTimeouts)
135+
WsRequestContext(authHeaders = authHeaders, explTimeouts = explTimeouts),
136+
recoverErrors
137+
)
138+
}
139+
140+
private class AnthropicBedrockServiceClassImpl(
141+
override val connectionInfo: BedrockConnectionSettings,
142+
explTimeouts: Option[Timeouts] = None
143+
)(
144+
implicit val ec: ExecutionContext,
145+
val materializer: Materializer
146+
) extends AnthropicBedrockServiceImpl {
147+
148+
// Play WS engine
149+
override protected val engine: WSClientEngine with WSClientEngineStreamExtra =
150+
PlayWSStreamClientEngine(
151+
coreUrl = bedrockCoreUrl(connectionInfo.region),
152+
WsRequestContext(explTimeouts = explTimeouts),
153+
recoverErrors
101154
)
102155
}
103156

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package io.cequence.openaiscala.anthropic.service.impl
2+
3+
import io.cequence.openaiscala.OpenAIScalaClientException
4+
import io.cequence.openaiscala.anthropic.JsonFormats
5+
import io.cequence.openaiscala.anthropic.domain.{ChatRole, Content, Message}
6+
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, SystemMessageContent}
7+
import io.cequence.openaiscala.anthropic.domain.response.ContentBlockDelta
8+
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
9+
import io.cequence.openaiscala.anthropic.service.{AnthropicService, HandleAnthropicErrorCodes}
10+
import io.cequence.wsclient.service.WSClientWithEngineTypes.WSClientWithStreamEngine
11+
import org.slf4j.LoggerFactory
12+
import play.api.libs.json.{JsString, JsValue, Json, Writes}
13+
import com.typesafe.scalalogging.Logger
14+
import io.cequence.wsclient.JsonUtil.JsonOps
15+
16+
trait Anthropic
17+
extends AnthropicService
18+
with WSClientWithStreamEngine
19+
with HandleAnthropicErrorCodes
20+
with JsonFormats {
21+
22+
protected val logger = Logger(LoggerFactory.getLogger(this.getClass))
23+
24+
protected def createBodyParamsForMessageCreation(
25+
messages: Seq[Message],
26+
settings: AnthropicCreateMessageSettings,
27+
stream: Option[Boolean],
28+
ignoreModel: Boolean = false
29+
): Seq[(Param, Option[JsValue])] = {
30+
assert(messages.nonEmpty, "At least one message expected.")
31+
32+
val (system, nonSystem) = messages.partition(_.isSystem)
33+
34+
assert(nonSystem.head.role == ChatRole.User, "First non-system message must be from user.")
35+
assert(
36+
system.size <= 1,
37+
"System message can be only 1. Use SystemMessageContent to include more content blocks."
38+
)
39+
40+
val messageJsons = nonSystem.map(Json.toJson(_))
41+
42+
val systemJson: Seq[JsValue] = system.map {
43+
case SystemMessage(text, cacheControl) =>
44+
if (cacheControl.isEmpty) JsString(text)
45+
else {
46+
val blocks =
47+
Seq(Content.ContentBlockBase(Content.ContentBlock.TextBlock(text), cacheControl))
48+
49+
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
50+
}
51+
case SystemMessageContent(blocks) =>
52+
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
53+
}
54+
55+
jsonBodyParams(
56+
Param.messages -> Some(messageJsons),
57+
Param.model -> (if (ignoreModel) None else Some(settings.model)),
58+
Param.system -> {
59+
if (system.isEmpty) None
60+
else Some(systemJson.head)
61+
},
62+
Param.max_tokens -> Some(settings.max_tokens),
63+
Param.metadata -> { if (settings.metadata.isEmpty) None else Some(settings.metadata) },
64+
Param.stop_sequences -> {
65+
if (settings.stop_sequences.nonEmpty) Some(settings.stop_sequences) else None
66+
},
67+
Param.stream -> stream,
68+
Param.temperature -> settings.temperature,
69+
Param.top_p -> settings.top_p,
70+
Param.top_k -> settings.top_k
71+
)
72+
}
73+
74+
protected def serializeStreamedJson(json: JsValue): Option[ContentBlockDelta] =
75+
(json \ "error").toOption.map { error =>
76+
logger.error(s"Error in streamed response: ${error.toString()}")
77+
throw new OpenAIScalaClientException(error.toString())
78+
}.getOrElse {
79+
val jsonType = (json \ "type").as[String]
80+
81+
// TODO: for now, we return only ContentBlockDelta
82+
jsonType match {
83+
case "message_start" => None // json.asSafe[CreateMessageChunkResponse]
84+
case "content_block_start" => None
85+
case "ping" => None
86+
case "content_block_delta" => Some(json.asSafe[ContentBlockDelta])
87+
case "content_block_stop" => None
88+
case "message_delta" => None
89+
case "message_stop" => None
90+
case _ =>
91+
logger.error(s"Unknown message type: $jsonType")
92+
throw new OpenAIScalaClientException(s"Unknown message type: $jsonType")
93+
}
94+
}
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package io.cequence.openaiscala.anthropic.service.impl
2+
3+
import akka.NotUsed
4+
import akka.stream.javadsl.{Framing, FramingTruncation}
5+
import akka.stream.scaladsl.Source
6+
import akka.util.ByteString
7+
import io.cequence.openaiscala.anthropic.domain.Message
8+
import io.cequence.openaiscala.anthropic.domain.response.{
9+
ContentBlockDelta,
10+
CreateMessageResponse
11+
}
12+
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
13+
import io.cequence.wsclient.ResponseImplicits.JsonSafeOps
14+
import play.api.libs.json.{JsString, JsValue, Json}
15+
16+
import scala.concurrent.Future
17+
18+
private[service] trait AnthropicBedrockServiceImpl extends Anthropic with BedrockAuthHelper {
19+
20+
override protected type PEP = String
21+
override protected type PT = Param
22+
23+
private def invokeEndpoint(model: String) = s"model/$model/invoke"
24+
private def invokeWithResponseStreamEndpoint(model: String) =
25+
s"model/$model/invoke-with-response-stream"
26+
private val serviceName = "bedrock"
27+
28+
private val bedrockAnthropicVersion = "bedrock-2023-05-31"
29+
30+
override def createMessage(
31+
messages: Seq[Message],
32+
settings: AnthropicCreateMessageSettings
33+
): Future[CreateMessageResponse] = {
34+
val coreBodyParams =
35+
createBodyParamsForMessageCreation(messages, settings, stream = None, ignoreModel = true)
36+
val bodyParams =
37+
coreBodyParams :+ (Param.anthropic_version -> Some(JsString(bedrockAnthropicVersion)))
38+
39+
val jsBodyObject = toJsBodyObject(paramTuplesToStrings(bodyParams))
40+
val endpoint = invokeEndpoint(settings.model)
41+
42+
val extraHeaders = createSignatureHeaders(
43+
"POST",
44+
createURL(Some(endpoint)),
45+
headers = requestContext.authHeaders,
46+
jsBodyObject
47+
)
48+
49+
execPOST(
50+
endpoint,
51+
bodyParams = bodyParams,
52+
extraHeaders = extraHeaders
53+
).map(
54+
_.asSafeJson[CreateMessageResponse]
55+
)
56+
}
57+
58+
override def createMessageStreamed(
59+
messages: Seq[Message],
60+
settings: AnthropicCreateMessageSettings
61+
): Source[ContentBlockDelta, NotUsed] = {
62+
val coreBodyParams =
63+
createBodyParamsForMessageCreation(messages, settings, stream = None, ignoreModel = true)
64+
val bodyParams =
65+
coreBodyParams :+ (Param.anthropic_version -> Some(JsString(bedrockAnthropicVersion)))
66+
67+
val stringParams = paramTuplesToStrings(bodyParams)
68+
val jsBodyObject = toJsBodyObject(stringParams)
69+
val endpoint = invokeWithResponseStreamEndpoint(settings.model)
70+
71+
val extraHeaders = createSignatureHeaders(
72+
"POST",
73+
createURL(Some(endpoint)),
74+
headers = requestContext.authHeaders,
75+
jsBodyObject
76+
)
77+
78+
engine
79+
.execRawStream(
80+
endpoint,
81+
"POST",
82+
endPointParam = None,
83+
params = Nil,
84+
bodyParams = stringParams,
85+
extraHeaders = extraHeaders
86+
)
87+
.via(
88+
Framing.delimiter(
89+
ByteString(":content-type"),
90+
maximumFrameLength = 65536,
91+
FramingTruncation.ALLOW
92+
)
93+
)
94+
.via(AwsEventStreamEventParser.flow) // parse frames into JSON with "bytes"
95+
.collect { case Some(x) => x }
96+
.via(AwsEventStreamBytesDecoder.flow) // decode the "
97+
.map(serializeStreamedJson)
98+
.collect { case Some(delta) => delta }
99+
}
100+
101+
protected def createSignatureHeaders(
102+
method: String,
103+
url: String,
104+
headers: Seq[(String, String)],
105+
body: JsValue
106+
): Seq[(String, String)] = {
107+
val connectionSettings = connectionInfo
108+
109+
addAuthHeaders(
110+
method,
111+
url,
112+
headers.toMap,
113+
Json.stringify(body),
114+
accessKey = connectionSettings.accessKey,
115+
secretKey = connectionSettings.secretKey,
116+
region = connectionSettings.region,
117+
service = serviceName
118+
).toSeq
119+
}
120+
121+
def connectionInfo: BedrockConnectionSettings
122+
}
123+
124+
case class BedrockConnectionSettings(
125+
accessKey: String,
126+
secretKey: String,
127+
region: String
128+
)

0 commit comments

Comments
 (0)