Commit 1a56b12

Support for custom-URL OpenAI-compatible services (e.g. FastChat), exposing either the core endpoints (completion, chat completion, etc.) or all endpoints. An explicit Azure instance is provided as well. Closes #32, #41. Note that this is an experimental feature and still needs to be properly tested.
1 parent 5d19760 commit 1a56b12
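
As a quick illustration of what the custom-URL support enables, here is a minimal sketch of talking to a locally running FastChat server through the new OpenAICoreServiceFactory introduced in this commit (see the diff below). The URL, model name, and runtime setup are assumptions for the example, not part of the commit:

import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.openaiscala.domain.settings.CreateCompletionSettings
import io.cequence.openaiscala.service.OpenAICoreServiceFactory

import scala.concurrent.ExecutionContext

object FastChatExample extends App {
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: Materializer = Materializer(system)
  implicit val ec: ExecutionContext = system.dispatcher

  // hypothetical FastChat endpoint serving the OpenAI-compatible REST API
  val service = OpenAICoreServiceFactory(coreUrl = "http://localhost:8000/v1/")

  service
    .createCompletion(
      "What is the capital of France?",
      CreateCompletionSettings(model = "vicuna-7b-v1.3") // model name is an assumption
    )
    .map(println)
    .andThen { case _ => system.terminate() }
}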

9 files changed: +455 additions, −283 deletions


openai-client-stream/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceStreamedImpl.scala

Lines changed: 8 additions & 3 deletions
@@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext
 private trait OpenAIServiceStreamedExtraImpl
     extends OpenAIServiceStreamedExtra
     with WSStreamRequestHelper {
-  this: OpenAIServiceImpl =>
+  this: OpenAIServiceClassImpl =>

   override def createCompletionStreamed(
     prompt: String,
@@ -89,6 +89,11 @@ object OpenAIServiceStreamedFactory
   )(
     implicit ec: ExecutionContext,
     materializer: Materializer
-  ): OpenAIService with OpenAIServiceStreamedExtra =
-    new OpenAIServiceImpl(apiKey, orgId, timeouts) with OpenAIServiceStreamedExtraImpl
+  ): OpenAIService with OpenAIServiceStreamedExtra = {
+    val orgIdHeader = orgId.map(("OpenAI-Organization", _))
+    val authHeaders = orgIdHeader ++: Seq(("Authorization", s"Bearer $apiKey"))
+
+    new OpenAIServiceClassImpl(defaultCoreUrl, authHeaders, timeouts)
+      with OpenAIServiceStreamedExtraImpl
+  }
 }
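
For context on the refactored streamed factory, a hedged usage sketch; the prompt, settings, and the assumption that createCompletionStreamed yields an Akka Streams Source of partial TextCompletionResponse chunks are illustrative, not taken verbatim from this diff:

import akka.actor.ActorSystem
import akka.stream.Materializer
import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.domain.settings.CreateCompletionSettings
import io.cequence.openaiscala.service.OpenAIServiceStreamedFactory

import scala.concurrent.ExecutionContext

object StreamedCompletionExample extends App {
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: Materializer = Materializer(system)
  implicit val ec: ExecutionContext = system.dispatcher

  // still points at api.openai.com (defaultCoreUrl); only the auth-header wiring changed
  val service = OpenAIServiceStreamedFactory(apiKey = sys.env("OPENAI_API_KEY"))

  service
    .createCompletionStreamed(
      "Write a haiku about Scala.",
      CreateCompletionSettings(model = "text-davinci-003") // model name is an assumption
    )
    .runWith(Sink.foreach(chunk => println(chunk)))
    .andThen { case _ => system.terminate() }
}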
Lines changed: 28 additions & 0 deletions (new file)

package io.cequence.openaiscala.service

import akka.stream.Materializer
import io.cequence.openaiscala.service.ws.Timeouts

import scala.concurrent.ExecutionContext

object OpenAICoreServiceFactory {

  def apply(
    coreUrl: String,
    authHeaders: Seq[(String, String)] = Nil,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAICoreService =
    new OpenAICoreServiceClassImpl(coreUrl, authHeaders, timeouts)
}

private class OpenAICoreServiceClassImpl(
  val coreUrl: String,
  val authHeaders: Seq[(String, String)],
  val explTimeouts: Option[Timeouts]
)(
  implicit val ec: ExecutionContext,
  val materializer: Materializer
) extends OpenAICoreServiceImpl
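
The companion factory simply instantiates a private class that mixes in the OpenAICoreServiceImpl trait shown next. If the target OpenAI-compatible service sits behind authentication, the optional authHeaders and timeouts parameters can be supplied; a hedged sketch, where the gateway URL, the header value, and the millisecond interpretation of Timeouts are assumptions:

import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.openaiscala.service.OpenAICoreServiceFactory
import io.cequence.openaiscala.service.ws.Timeouts

import scala.concurrent.ExecutionContext

object ProtectedGatewayExample extends App {
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: Materializer = Materializer(system)
  implicit val ec: ExecutionContext = system.dispatcher

  // hypothetical token-protected OpenAI-compatible gateway
  val service = OpenAICoreServiceFactory(
    coreUrl = "https://llm-gateway.example.com/v1/",
    authHeaders = Seq(("Authorization", s"Bearer ${sys.env("GATEWAY_TOKEN")}")),
    timeouts = Some(Timeouts(requestTimeout = Some(60000))) // assuming millisecond values
  )
}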
Lines changed: 192 additions & 0 deletions (new file)

package io.cequence.openaiscala.service

import akka.stream.Materializer
import play.api.libs.ws.StandaloneWSRequest
import play.api.libs.json.{JsArray, JsNull, JsObject, JsValue, Json}
import io.cequence.openaiscala.JsonUtil.JsonOps
import io.cequence.openaiscala.JsonFormats._
import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.domain.settings._
import io.cequence.openaiscala.domain.response._
import io.cequence.openaiscala.domain.{
  BaseMessageSpec,
  FunMessageSpec,
  MessageSpec
}
import io.cequence.openaiscala.service.ws.{Timeouts, WSRequestHelper}

import scala.concurrent.{ExecutionContext, Future}

/**
 * Private impl. of [[OpenAICoreService]].
 *
 * @param ec
 * @param materializer
 *
 * @since July 2023
 */
private trait OpenAICoreServiceImpl extends OpenAICoreService with WSRequestHelper {

  override protected type PEP = EndPoint
  override protected type PT = Param

  protected implicit val ec: ExecutionContext
  protected implicit val materializer: Materializer

  protected val explTimeouts: Option[Timeouts]
  protected val authHeaders: Seq[(String, String)]

  override protected def timeouts: Timeouts =
    explTimeouts.getOrElse(
      Timeouts(
        requestTimeout = Some(defaultRequestTimeout),
        readTimeout = Some(defaultReadoutTimeout)
      )
    )

  override def listModels: Future[Seq[ModelInfo]] =
    execGET(EndPoint.models).map { response =>
      (response.asSafe[JsObject] \ "data").toOption.map {
        _.asSafeArray[ModelInfo]
      }.getOrElse(
        throw new OpenAIScalaClientException(
          s"The attribute 'data' is not present in the response: ${response.toString()}."
        )
      )
    }

  override def createCompletion(
    prompt: String,
    settings: CreateCompletionSettings
  ): Future[TextCompletionResponse] =
    execPOST(
      EndPoint.completions,
      bodyParams = createBodyParamsForCompletion(prompt, settings, stream = false)
    ).map(
      _.asSafe[TextCompletionResponse]
    )

  protected def createBodyParamsForCompletion(
    prompt: String,
    settings: CreateCompletionSettings,
    stream: Boolean
  ): Seq[(Param, Option[JsValue])] =
    jsonBodyParams(
      Param.prompt -> Some(prompt),
      Param.model -> Some(settings.model),
      Param.suffix -> settings.suffix,
      Param.max_tokens -> settings.max_tokens,
      Param.temperature -> settings.temperature,
      Param.top_p -> settings.top_p,
      Param.n -> settings.n,
      Param.stream -> Some(stream),
      Param.logprobs -> settings.logprobs,
      Param.echo -> settings.echo,
      Param.stop -> {
        settings.stop.size match {
          case 0 => None
          case 1 => Some(settings.stop.head)
          case _ => Some(settings.stop)
        }
      },
      Param.presence_penalty -> settings.presence_penalty,
      Param.frequency_penalty -> settings.frequency_penalty,
      Param.best_of -> settings.best_of,
      Param.logit_bias -> {
        if (settings.logit_bias.isEmpty) None else Some(settings.logit_bias)
      },
      Param.user -> settings.user
    )

  override def createChatCompletion(
    messages: Seq[MessageSpec],
    settings: CreateChatCompletionSettings
  ): Future[ChatCompletionResponse] =
    execPOST(
      EndPoint.chat_completions,
      bodyParams = createBodyParamsForChatCompletion(messages, settings, stream = false)
    ).map(
      _.asSafe[ChatCompletionResponse]
    )

  protected def createBodyParamsForChatCompletion(
    messages: Seq[BaseMessageSpec],
    settings: CreateChatCompletionSettings,
    stream: Boolean
  ): Seq[(Param, Option[JsValue])] = {
    assert(messages.nonEmpty, "At least one message expected.")
    val messageJsons = messages.map(_ match {
      case m: MessageSpec =>
        Json.toJson(m)(messageSpecFormat)
      case m: FunMessageSpec =>
        val json = Json.toJson(m)(funMessageSpecFormat)
        // if the content is empty, add a null value (expected by the API)
        m.content
          .map(_ => json)
          .getOrElse(
            json.as[JsObject].+("content" -> JsNull)
          )
    })

    jsonBodyParams(
      Param.messages -> Some(JsArray(messageJsons)),
      Param.model -> Some(settings.model),
      Param.temperature -> settings.temperature,
      Param.top_p -> settings.top_p,
      Param.n -> settings.n,
      Param.stream -> Some(stream),
      Param.stop -> {
        settings.stop.size match {
          case 0 => None
          case 1 => Some(settings.stop.head)
          case _ => Some(settings.stop)
        }
      },
      Param.max_tokens -> settings.max_tokens,
      Param.presence_penalty -> settings.presence_penalty,
      Param.frequency_penalty -> settings.frequency_penalty,
      Param.logit_bias -> {
        if (settings.logit_bias.isEmpty) None else Some(settings.logit_bias)
      },
      Param.user -> settings.user
    )
  }

  override def createEmbeddings(
    input: Seq[String],
    settings: CreateEmbeddingsSettings
  ): Future[EmbeddingResponse] =
    execPOST(
      EndPoint.embeddings,
      bodyParams = jsonBodyParams(
        Param.input -> {
          input.size match {
            case 0 => None
            case 1 => Some(input.head)
            case _ => Some(input)
          }
        },
        Param.model -> Some(settings.model),
        Param.user -> settings.user
      )
    ).map(
      _.asSafe[EmbeddingResponse]
    )

  // auth

  override protected def getWSRequestOptional(
    endPoint: Option[PEP],
    endPointParam: Option[String],
    params: Seq[(PT, Option[Any])] = Nil
  ): StandaloneWSRequest#Self =
    super.getWSRequestOptional(endPoint, endPointParam, params).addHttpHeaders(authHeaders: _*)

  override protected def getWSRequest(
    endPoint: Option[PEP],
    endPointParam: Option[String],
    params: Seq[(PT, Any)] = Nil
  ): StandaloneWSRequest#Self =
    super.getWSRequest(endPoint, endPointParam, params).addHttpHeaders(authHeaders: _*)
}
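
To show how the core endpoints above would be exercised, a small hedged helper; MessageSpec(ChatRole.User, ...) and the model name reflect my understanding of the library's domain classes and are not part of this diff:

import io.cequence.openaiscala.domain.{ChatRole, MessageSpec}
import io.cequence.openaiscala.domain.response.ChatCompletionResponse
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.service.OpenAICoreService

import scala.concurrent.Future

object CoreServiceUsage {

  // `service` would typically be created via OpenAICoreServiceFactory (see above)
  def askOnce(
    service: OpenAICoreService,
    question: String
  ): Future[ChatCompletionResponse] =
    service.createChatCompletion(
      messages = Seq(MessageSpec(ChatRole.User, question)),
      settings = CreateChatCompletionSettings(model = "gpt-3.5-turbo") // model is an assumption
    )
}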
Lines changed: 119 additions & 0 deletions (new file)

package io.cequence.openaiscala.service

import akka.stream.Materializer
import io.cequence.openaiscala.service.ws.Timeouts

import scala.concurrent.ExecutionContext

object OpenAIServiceFactory
    extends OpenAIServiceFactoryHelper[OpenAIService]
    with OpenAIServiceConsts {

  override def apply(
    apiKey: String,
    orgId: Option[String] = None,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAIService = {
    val orgIdHeader = orgId.map(("OpenAI-Organization", _))
    val authHeaders = orgIdHeader ++: Seq(("Authorization", s"Bearer $apiKey"))

    apply(defaultCoreUrl, authHeaders, timeouts)
  }

  /**
   * Create an OpenAI Service for Azure using an API key.
   *
   * Note that not all endpoints are supported! Check <a
   * href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference">the
   * Azure OpenAI API documentation</a> for more information.
   *
   * @param resourceName
   *   The name of your Azure OpenAI Resource.
   * @param deploymentId
   *   The deployment name you chose when you deployed the model.
   * @param apiVersion
   *   The API version to use for this operation. This follows the YYYY-MM-DD format. Supported
   *   versions: 2023-03-15-preview, 2022-12-01, 2023-05-15, and 2023-06-01-preview
   */
  def forAzureWithApiKey(
    resourceName: String,
    deploymentId: String,
    apiVersion: String,
    apiKey: String,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAIService = {
    val authHeaders = Seq(("api-key", apiKey))
    forAzureAux(resourceName, deploymentId, apiVersion, authHeaders, timeouts)
  }

  /**
   * Create an OpenAI Service for Azure using an access token (Azure Active Directory
   * authentication).
   *
   * Note that not all endpoints are supported! Check <a
   * href="https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference">the
   * Azure OpenAI API documentation</a> for more information.
   *
   * @param resourceName
   *   The name of your Azure OpenAI Resource.
   * @param deploymentId
   *   The deployment name you chose when you deployed the model.
   * @param apiVersion
   *   The API version to use for this operation. This follows the YYYY-MM-DD format. Supported
   *   versions: 2023-03-15-preview, 2022-12-01, 2023-05-15, and 2023-06-01-preview
   */
  def forAzureWithAccessToken(
    resourceName: String,
    deploymentId: String,
    apiVersion: String,
    accessToken: String,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAIService = {
    val authHeaders = Seq(("Authorization", s"Bearer $accessToken"))
    forAzureAux(resourceName, deploymentId, apiVersion, authHeaders, timeouts)
  }

  private def forAzureAux(
    resourceName: String,
    deploymentId: String,
    apiVersion: String,
    authHeaders: Seq[(String, String)],
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAIService = {
    val coreUrl =
      s"https://${resourceName}.openai.azure.com/openai/deployments/${deploymentId}/completions?api-version=${apiVersion}"

    apply(coreUrl, authHeaders, timeouts)
  }

  def apply(
    coreUrl: String,
    authHeaders: Seq[(String, String)],
    timeouts: Option[Timeouts]
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): OpenAIService =
    new OpenAIServiceClassImpl(coreUrl, authHeaders, timeouts)
}

private class OpenAIServiceClassImpl(
  val coreUrl: String,
  val authHeaders: Seq[(String, String)],
  val explTimeouts: Option[Timeouts]
)(
  implicit val ec: ExecutionContext,
  val materializer: Materializer
) extends OpenAIServiceImpl
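
Finally, a hedged sketch of the new Azure entry point; the resource name, deployment id, model, and completion call are placeholders, and the apiVersion uses one of the values listed in the scaladoc above:

import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.openaiscala.domain.settings.CreateCompletionSettings
import io.cequence.openaiscala.service.OpenAIServiceFactory

import scala.concurrent.ExecutionContext

object AzureExample extends App {
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: Materializer = Materializer(system)
  implicit val ec: ExecutionContext = system.dispatcher

  val service = OpenAIServiceFactory.forAzureWithApiKey(
    resourceName = "my-openai-resource",   // placeholder
    deploymentId = "my-gpt-35-deployment", // placeholder
    apiVersion = "2023-05-15",
    apiKey = sys.env("AZURE_OPENAI_API_KEY")
  )

  service
    .createCompletion("Say hello.", CreateCompletionSettings(model = "gpt-35-turbo")) // model is an assumption
    .map(println)
    .andThen { case _ => system.terminate() }
}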
