Commit 1d4841f

New OpenAIService audio functions/endpoints - createAudioTranscription and createAudioTranslation (impl, settings, response, json formatting)

1 parent 3befc71 commit 1d4841f

11 files changed (+214, -25 lines)

openai-client-stream/README.md

Lines changed: 4 additions & 4 deletions
@@ -1,15 +1,15 @@
-# OpenAI Scala Client - Stream Support [![version](https://img.shields.io/badge/version-0.2.0-green.svg)](https://cequence.io) [![License](https://img.shields.io/badge/License-MIT-lightgrey.svg)](https://opensource.org/licenses/MIT)
+# OpenAI Scala Client - Stream Support [![version](https://img.shields.io/badge/version-0.3.0-green.svg)](https://cequence.io) [![License](https://img.shields.io/badge/License-MIT-lightgrey.svg)](https://opensource.org/licenses/MIT)
 
 This module provides streaming support for the client. Note that the full project documentation can be found [here](../README.md).
 
 ## Installation 🚀
 
-The currently supported Scala versions are **2.12** and **2.13**.
+The currently supported Scala versions are **2.12, 2.13**, and **3**.
 
 To pull the library you have to add the following dependency to your *build.sbt*
 
 ```
-"io.cequence" %% "openai-scala-client-stream" % "0.2.0"
+"io.cequence" %% "openai-scala-client-stream" % "0.3.0"
 ```
 
 or to *pom.xml* (if you use maven)
@@ -18,6 +18,6 @@ or to *pom.xml* (if you use maven)
   <dependency>
     <groupId>io.cequence</groupId>
     <artifactId>openai-scala-client-stream_2.12</artifactId>
-    <version>0.2.0</version>
+    <version>0.3.0</version>
   </dependency>
 ```

openai-client/src/main/scala/io/cequence/openaiscala/service/Command.scala

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@ object Command extends Enumeration {
   val images_edits = Value("images/edits")
   val images_variations = Value("images/variations")
   val embeddings = Value
+  val audio_transcriptions = Value("audio/transcriptions")
+  val audio_translations = Value("audio/translations")
   val files = Value
   val fine_tunes = Value("fine-tunes")
   val moderations = Value
@@ -23,5 +25,5 @@ object Tag extends Enumeration {
     input, image, mask, instruction, size, response_format, file, purpose, file_id,
     training_file, validation_file, n_epochs, batch_size, learning_rate_multiplier, prompt_loss_weight,
     compute_classification_metrics, classification_n_classes, classification_positive_class,
-    classification_betas, fine_tune_id = Value
+    classification_betas, fine_tune_id, language = Value
 }
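
The new `audio_transcriptions` and `audio_translations` values use `Value("...")` to override Scala's default `Enumeration` naming, so a value's `toString` yields the endpoint path directly. A minimal sketch of that mechanism (the base URL here is an assumption for illustration, not taken from this diff):

```scala
object Command extends Enumeration {
  val embeddings = Value                                    // toString == "embeddings" (derived from the val name)
  val audio_transcriptions = Value("audio/transcriptions")  // toString == "audio/transcriptions" (explicit name)
}

// Hypothetical URL assembly, assuming the standard OpenAI base URL:
val url = s"https://api.openai.com/v1/${Command.audio_transcriptions}"
// url == "https://api.openai.com/v1/audio/transcriptions"
```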

openai-client/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceImpl.scala

Lines changed: 66 additions & 1 deletion
@@ -2,7 +2,7 @@ package io.cequence.openaiscala.service
 
 import akka.stream.Materializer
 import play.api.libs.ws.StandaloneWSRequest
-import play.api.libs.json.JsObject
+import play.api.libs.json.{JsArray, JsObject, JsValue, Json}
 import io.cequence.openaiscala.JsonUtil.JsonOps
 import io.cequence.openaiscala.JsonFormats._
 import io.cequence.openaiscala.OpenAIScalaClientException
@@ -259,6 +259,71 @@ private class OpenAIServiceImpl(
       _.asSafe[EmbeddingResponse]
     )
 
+  override def createAudioTranscription(
+    file: File,
+    prompt: Option[String],
+    settings: CreateTranscriptionSettings
+  ): Future[TranscriptResponse] =
+    execPOSTMultipartWithStatusString(
+      Command.audio_transcriptions,
+      fileParams = Seq(Tag.file -> file),
+      bodyParams = Seq(
+        Tag.prompt -> prompt,
+        Tag.model -> Some(settings.model),
+        Tag.response_format -> settings.response_format.map(_.toString),
+        Tag.temperature -> settings.temperature,
+        Tag.language -> settings.language
+      )
+    ).map(processAudioTranscriptResponse(settings.response_format))
+
+  override def createAudioTranslation(
+    file: File,
+    prompt: Option[String],
+    settings: CreateTranslationSettings
+  ): Future[TranscriptResponse] =
+    execPOSTMultipartWithStatusString(
+      Command.audio_translations,
+      fileParams = Seq(Tag.file -> file),
+      bodyParams = Seq(
+        Tag.prompt -> prompt,
+        Tag.model -> Some(settings.model),
+        Tag.response_format -> settings.response_format.map(_.toString),
+        Tag.temperature -> settings.temperature
+      )
+    ).map(processAudioTranscriptResponse(settings.response_format))
+
+  private def processAudioTranscriptResponse(
+    responseFormat: Option[TranscriptResponseFormatType.Value])(
+    stringRichResponse: RichStringResponse
+  ) = {
+    val stringResponse = handleErrorResponse(stringRichResponse)
+
+    def textFromJsonString(json: JsValue) =
+      (json.asSafe[JsObject] \ "text").toOption.map {
+        _.asSafe[String]
+      }.getOrElse(
+        throw new OpenAIScalaClientException(s"The attribute 'text' is not present in the response: ${stringResponse}.")
+      )
+
+    val FormatType = TranscriptResponseFormatType
+
+    responseFormat.getOrElse(FormatType.json) match {
+      case FormatType.json =>
+        val json = Json.parse(stringResponse)
+        TranscriptResponse(textFromJsonString(json))
+
+      case FormatType.verbose_json =>
+        val json = Json.parse(stringResponse)
+        TranscriptResponse(
+          text = textFromJsonString(json),
+          verboseJson = Some(Json.prettyPrint(json))
+        )
+
+      case FormatType.text | FormatType.srt | FormatType.vtt =>
+        TranscriptResponse(stringResponse)
+    }
+  }
+
   override def listFiles: Future[Seq[FileInfo]] =
     execGET(Command.files).map { response =>
       (response.asSafe[JsObject] \ "data").toOption.map {
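
The response handling mirrors the `response_format` options: `json` and `verbose_json` responses are parsed and the `text` attribute is extracted (with `verbose_json` also retaining the pretty-printed payload), while `text`, `srt`, and `vtt` responses are returned verbatim. A standalone sketch of the `json` branch using plain Play JSON, with `asOpt` standing in for the project's `asSafe` helper and a fabricated sample payload:

```scala
import play.api.libs.json.Json

case class TranscriptResponse(text: String, verboseJson: Option[String] = None)

val raw  = """{"text": "Hello there!"}""" // sample (made-up) transcription response body
val json = Json.parse(raw)

// extract the mandatory 'text' attribute, failing loudly when it is absent
val response = (json \ "text").asOpt[String] match {
  case Some(text) => TranscriptResponse(text)
  case None       => sys.error(s"The attribute 'text' is not present in the response: $raw")
}
// response.text == "Hello there!"
```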

openai-client/src/main/scala/io/cequence/openaiscala/service/ws/WSRequestHelper.scala

Lines changed: 48 additions & 15 deletions
@@ -32,8 +32,9 @@ trait WSRequestHelper extends WSHelper {
 
   private val defaultAcceptableStatusCodes = Seq(200)
 
-  protected type RichJsResponse = Either[JsValue, (Int, String)]
-  protected type RichStringResponse = Either[String, (Int, String)]
+  protected type RichResponse[T] = Either[T, (Int, String)]
+  protected type RichJsResponse = RichResponse[JsValue]
+  protected type RichStringResponse = RichResponse[String]
 
   /////////
   // GET //
@@ -105,22 +106,42 @@ trait WSRequestHelper extends WSHelper {
     acceptableStatusCodes: Seq[Int] = defaultAcceptableStatusCodes
   ): Future[RichJsResponse] = {
     val request = getWSRequestOptional(Some(endPoint), endPointParam, params)
+    val formData = createMultipartFormData(fileParams, bodyParams)
 
-    // create a multipart form data holder contain classic data (key-value) parts as well as file parts
-    val formData = MultipartFormData(
-      dataParts = bodyParams.collect { case (key, Some(value)) =>
-        (key.toString, Seq(value.toString))
-      }.toMap,
+    implicit val writeable: BodyWritable[MultipartFormData] = writeableOf_MultipartFormData("utf-8")
 
-      // TODO: we can potentially use here header-file-names as well (if provided as function's params)
-      files = fileParams.map { case (key, file) => FilePart(key.toString, file.getPath) }
-    )
+    execPOSTJsonAux(request, formData, Some(endPoint), acceptableStatusCodes)
+  }
+
+  protected def execPOSTMultipartWithStatusString(
+    endPoint: PEP,
+    endPointParam: Option[String] = None,
+    params: Seq[(PT, Option[Any])] = Nil,
+    fileParams: Seq[(PT, File)] = Nil,
+    bodyParams: Seq[(PT, Option[Any])] = Nil,
+    acceptableStatusCodes: Seq[Int] = defaultAcceptableStatusCodes
+  ): Future[RichStringResponse] = {
+    val request = getWSRequestOptional(Some(endPoint), endPointParam, params)
+    val formData = createMultipartFormData(fileParams, bodyParams)
 
     implicit val writeable: BodyWritable[MultipartFormData] = writeableOf_MultipartFormData("utf-8")
 
-    execPOSTAux(request, formData, Some(endPoint), acceptableStatusCodes)
+    execPOSTStringAux(request, formData, Some(endPoint), acceptableStatusCodes)
   }
 
+  // create a multipart form data holder containing classic data (key-value) parts as well as file parts
+  private def createMultipartFormData(
+    fileParams: Seq[(PT, File)] = Nil,
+    bodyParams: Seq[(PT, Option[Any])] = Nil
+  ) = MultipartFormData(
+    dataParts = bodyParams.collect { case (key, Some(value)) =>
+      (key.toString, Seq(value.toString))
+    }.toMap,
+
+    // TODO: we can potentially use here header-file-names as well (if provided as function's params)
+    files = fileParams.map { case (key, file) => FilePart(key.toString, file.getPath) }
+  )
+
   protected def execPOST(
     endPoint: PEP,
     endPointParam: Option[String] = None,
@@ -141,10 +162,10 @@ trait WSRequestHelper extends WSHelper {
     val request = getWSRequestOptional(Some(endPoint), endPointParam, params)
     val bodyParamsX = bodyParams.collect { case (fieldName, Some(jsValue)) => (fieldName.toString, jsValue) }
 
-    execPOSTAux(request, JsObject(bodyParamsX), Some(endPoint), acceptableStatusCodes)
+    execPOSTJsonAux(request, JsObject(bodyParamsX), Some(endPoint), acceptableStatusCodes)
   }
 
-  protected def execPOSTAux[T: BodyWritable](
+  protected def execPOSTJsonAux[T: BodyWritable](
     request: StandaloneWSRequest,
     body: T,
     endPointForLogging: Option[PEP], // only for logging
@@ -156,6 +177,18 @@ trait WSRequestHelper extends WSHelper {
       endPointForLogging
     )
 
+  protected def execPOSTStringAux[T: BodyWritable](
+    request: StandaloneWSRequest,
+    body: T,
+    endPointForLogging: Option[PEP], // only for logging
+    acceptableStatusCodes: Seq[Int] = defaultAcceptableStatusCodes
+  ) =
+    execRequestStringAux(
+      request, _.post(body),
+      acceptableStatusCodes,
+      endPointForLogging
+    )
+
   ////////////
   // DELETE //
   ////////////
@@ -279,9 +312,9 @@ trait WSRequestHelper extends WSHelper {
     )
     params.map { case (paramName, value) => (paramName, value.map(toJson)) }
 
-  protected def handleErrorResponse(response: RichJsResponse) =
+  protected def handleErrorResponse[T](response: RichResponse[T]) =
     response match {
-      case Left(json) => json
+      case Left(data) => data
 
       case Right((errorCode, message)) => throw new OpenAIScalaClientException(s"Code ${errorCode} : ${message}")
     }
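
The key refactoring here is generalizing the two concrete `Either` aliases into `RichResponse[T]`, which lets a single `handleErrorResponse` serve both the JSON and plain-string execution paths. A self-contained sketch of the pattern (with `RuntimeException` standing in for the project's `OpenAIScalaClientException`):

```scala
object RichResponseSketch {
  // Left holds the successful payload; Right holds an (HTTP status, message) error pair
  type RichResponse[T] = Either[T, (Int, String)]

  def handleErrorResponse[T](response: RichResponse[T]): T =
    response match {
      case Left(data) => data
      case Right((errorCode, message)) =>
        throw new RuntimeException(s"Code ${errorCode} : ${message}")
    }

  def main(args: Array[String]): Unit = {
    println(handleErrorResponse(Left("all good")))       // prints: all good
    // handleErrorResponse(Right((401, "unauthorized"))) // would throw
  }
}
```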
openai-core/src/main/scala/io/cequence/openaiscala/domain/response/TranscriptResponse.scala

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+package io.cequence.openaiscala.domain.response
+
+case class TranscriptResponse(
+  text: String,
+  verboseJson: Option[String] = None
+)

openai-core/src/main/scala/io/cequence/openaiscala/domain/settings/CreateCompletionSettings.scala

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@ case class CreateCompletionSettings(
   // Most models have a context length of 2048 tokens (except for the newest models, which support 4096). Defaults to 16.
   max_tokens: Option[Int] = None,
 
-  // What sampling temperature to use. Higher values means the model will take more risks.
-  // Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+  // What sampling temperature to use, between 0 and 2.
+  // Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
   // We generally recommend altering this or top_p but not both. Defaults to 1.
   temperature: Option[Double] = None,
 
openai-core/src/main/scala/io/cequence/openaiscala/domain/settings/CreateImageSettings.scala

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ case class CreateImageSettings(
   size: Option[ImageSizeType.Value] = None,
 
   // The format in which the generated images are returned. Must be one of url or b64_json. Defaults to url
-  response_format: Option[ResponseFormatType.Value] = None,
+  response_format: Option[ImageResponseFormatType.Value] = None,
 
   // A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
   user: Option[String] = None
@@ -20,6 +20,6 @@ object ImageSizeType extends Enumeration {
   val Large = Value("1024x1024")
 }
 
-object ResponseFormatType extends Enumeration {
+object ImageResponseFormatType extends Enumeration {
   val url, b64_json = Value
 }
openai-core/src/main/scala/io/cequence/openaiscala/domain/settings/CreateTranscriptionSettings.scala

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+package io.cequence.openaiscala.domain.settings
+
+case class CreateTranscriptionSettings(
+  // ID of the model to use. Only whisper-1 is currently available.
+  model: String,
+
+  // The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
+  // Defaults to json.
+  response_format: Option[TranscriptResponseFormatType.Value] = None,
+
+  // The sampling temperature, between 0 and 1.
+  // Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
+  // If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
+  // Defaults to 0.
+  temperature: Option[Double] = None,
+
+  // The language of the input audio.
+  // Supplying the input language in ISO-639-1 ('en', 'de', 'es', etc.) format will improve accuracy and latency.
+  language: Option[String] = None
+)
+
+object TranscriptResponseFormatType extends Enumeration {
+  val json, text, srt, verbose_json, vtt = Value
+}
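
For illustration, requesting a verbose JSON transcript of German audio could be configured as below; the concrete values are hypothetical, not taken from this commit:

```scala
val settings = CreateTranscriptionSettings(
  model = "whisper-1",
  response_format = Some(TranscriptResponseFormatType.verbose_json),
  temperature = Some(0.2),  // low temperature keeps the transcript focused
  language = Some("de")     // ISO-639-1 code for German
)
```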
openai-core/src/main/scala/io/cequence/openaiscala/domain/settings/CreateTranslationSettings.scala

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+package io.cequence.openaiscala.domain.settings
+
+case class CreateTranslationSettings(
+  // ID of the model to use. Only whisper-1 is currently available.
+  model: String,
+
+  // The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
+  // Defaults to json.
+  response_format: Option[TranscriptResponseFormatType.Value] = None,
+
+  // The sampling temperature, between 0 and 1.
+  // Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
+  // If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
+  // Defaults to 0.
+  temperature: Option[Double] = None
+)

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIService.scala

Lines changed: 34 additions & 0 deletions
@@ -164,6 +164,40 @@ trait OpenAIService extends OpenAIServiceConsts {
     settings: CreateEmbeddingsSettings = DefaultSettings.CreateEmbeddings
   ): Future[EmbeddingResponse]
 
+  /**
+   * Transcribes audio into the input language.
+   *
+   * @param file The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
+   * @param prompt An optional text to guide the model's style or continue a previous audio segment.
+   *               The prompt should match the audio language.
+   * @param settings
+   * @return transcription text
+   *
+   * @see <a href="https://platform.openai.com/docs/api-reference/audio/create">OpenAI Doc</a>
+   */
+  def createAudioTranscription(
+    file: File,
+    prompt: Option[String] = None,
+    settings: CreateTranscriptionSettings = DefaultSettings.CreateTranscription
+  ): Future[TranscriptResponse]
+
+  /**
+   * Translates audio into English.
+   *
+   * @param file The audio file to translate, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
+   * @param prompt An optional text to guide the model's style or continue a previous audio segment.
+   *               The prompt should match the audio language.
+   * @param settings
+   * @return translation text
+   *
+   * @see <a href="https://platform.openai.com/docs/api-reference/audio/create">OpenAI Doc</a>
+   */
+  def createAudioTranslation(
+    file: File,
+    prompt: Option[String] = None,
+    settings: CreateTranslationSettings = DefaultSettings.CreateTranslation
+  ): Future[TranscriptResponse]
+
   /**
    * Returns a list of files that belong to the user's organization.
    *
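
Taken together with the defaults defined in `OpenAIServiceConsts`, calling the two new endpoints might look like the sketch below. Obtaining a `service` instance and the audio file names are outside this diff, so treat those parts as assumptions:

```scala
import java.io.File
import scala.concurrent.ExecutionContext.Implicits.global

val service: OpenAIService = ??? // e.g. obtained via the project's service factory

for {
  transcription <- service.createAudioTranscription(new File("speech.mp3"))
  translation   <- service.createAudioTranslation(new File("speech_de.mp3"))
} yield {
  println(transcription.text) // text in the audio's own language
  println(translation.text)   // English translation
}
```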

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceConsts.scala

Lines changed: 9 additions & 0 deletions
@@ -49,6 +49,15 @@ trait OpenAIServiceConsts {
     model = ModelId.text_embedding_ada_002
   )
 
+  val CreateTranscription = CreateTranscriptionSettings(
+    model = ModelId.whisper_1,
+    language = Some("en")
+  )
+
+  val CreateTranslation = CreateTranslationSettings(
+    model = ModelId.whisper_1
+  )
+
   val UploadFile = UploadFileSettings(
     purpose = "fine-tune"
   )
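
Note that `CreateTranscription` defaults `language` to `Some("en")`, so transcribing non-English audio would presumably mean overriding that hint; a sketch with a hypothetical file name:

```scala
service.createAudioTranscription(
  new File("speech_fr.mp3"),
  settings = DefaultSettings.CreateTranscription.copy(language = Some("fr"))
)
```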
