Skip to content

Commit efec83d

Browse files
committed
Google Gemini examples - generate content cached, with inline data, and with OpenAI adapter
1 parent 9823f74 commit efec83d

File tree

4 files changed

+262
-3
lines changed

4 files changed

+262
-3
lines changed

openai-examples/src/main/scala/io/cequence/openaiscala/examples/CreateChatCompletionStreamed.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package io.cequence.openaiscala.examples
33
import akka.stream.scaladsl.Sink
44
import io.cequence.openaiscala.domain._
55
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
6-
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIStreamedService
6+
import io.cequence.openaiscala.domain.settings.ReasoningEffort.medium
77
import io.cequence.openaiscala.service.OpenAIServiceFactory
88
import io.cequence.openaiscala.service.OpenAIStreamedServiceImplicits._
9+
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIStreamedService
10+
911
import scala.concurrent.Future
1012

1113
// requires `openai-scala-client-stream` as a dependency
@@ -26,8 +28,7 @@ object CreateChatCompletionStreamed extends ExampleBase[OpenAIStreamedService] {
2628
)
2729
.runWith(
2830
Sink.foreach { completion =>
29-
val content = completion.choices.headOption.flatMap(_.delta.content)
30-
print(content.getOrElse(""))
31+
print(completion.contentHead.getOrElse(""))
3132
}
3233
)
3334
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package io.cequence.openaiscala.examples.nonopenai
2+
3+
import com.typesafe.scalalogging.Logger
4+
import io.cequence.openaiscala.domain.NonOpenAIModelId
5+
import io.cequence.openaiscala.examples.ExampleBase
6+
import io.cequence.openaiscala.gemini.domain.ChatRole.User
7+
import io.cequence.openaiscala.gemini.domain.settings.{
8+
GenerateContentSettings,
9+
GenerationConfig
10+
}
11+
import io.cequence.openaiscala.gemini.domain.{CachedContent, Content, Expiration}
12+
import io.cequence.openaiscala.gemini.service.{GeminiService, GeminiServiceFactory}
13+
import org.slf4j.LoggerFactory
14+
15+
import scala.concurrent.Future
16+
import scala.io.Source
17+
18+
// requires `openai-scala-google-gemini-client` as a dependency and `GOOGLE_API_KEY` environment variable to be set
19+
object GoogleGeminiGenerateContentCached extends ExampleBase[GeminiService] {
20+
21+
override protected val service: GeminiService = GeminiServiceFactory()
22+
23+
protected val logger: Logger = Logger(LoggerFactory.getLogger(this.getClass))
24+
25+
private val systemPrompt = "You are a helpful assistant and expert in Norway."
26+
private val userPrompt = "Write the section 'Higher education in Norway' verbatim."
27+
private val knowledgeFile = getClass.getResource("/norway_wiki.md").getFile
28+
29+
private lazy val knowledgeContent = {
30+
val source = Source.fromFile(knowledgeFile)
31+
try source.mkString("")
32+
finally source.close()
33+
}
34+
35+
private val model = NonOpenAIModelId.gemini_1_5_flash_002
36+
37+
private val knowledgeTextContent: Content =
38+
Content.textPart(
39+
knowledgeContent,
40+
User
41+
)
42+
43+
override protected def run: Future[_] = {
44+
def listCachedContents =
45+
service.listCachedContents().map { cachedContentsResponse =>
46+
logger.info(
47+
s"Cached contents: ${cachedContentsResponse.cachedContents.flatMap(_.name).mkString(", ")}"
48+
)
49+
}
50+
51+
for {
52+
_ <- listCachedContents
53+
54+
saveCachedContent <- service.createCachedContent(
55+
CachedContent(
56+
contents = Seq(knowledgeTextContent),
57+
systemInstruction = Some(Content.textPart(systemPrompt, User)),
58+
model = model
59+
)
60+
)
61+
62+
cachedContentName = saveCachedContent.name.get
63+
64+
_ = logger.info(s"${cachedContentName} - expire time : " + saveCachedContent.expireTime)
65+
66+
_ <- listCachedContents
67+
68+
updatedCachedContent <- service.updateCachedContent(
69+
cachedContentName,
70+
Expiration.TTL("60s")
71+
)
72+
73+
_ = logger.info(
74+
s"${cachedContentName} - new expire time : " + updatedCachedContent.expireTime
75+
)
76+
77+
response <- service.generateContent(
78+
Seq(Content.textPart(userPrompt, User)),
79+
settings = GenerateContentSettings(
80+
model = model,
81+
generationConfig = Some(
82+
GenerationConfig(
83+
maxOutputTokens = Some(2000),
84+
temperature = Some(0.2)
85+
)
86+
),
87+
cachedContent = Some(cachedContentName)
88+
)
89+
)
90+
91+
_ = logger.info("Response : " + response.contentHeadText)
92+
93+
_ = {
94+
val usage = response.usageMetadata
95+
logger.info(
96+
s"""Usage
97+
|Prompt tokens : ${usage.promptTokenCount}
98+
|(cached) : ${usage.cachedContentTokenCount.getOrElse(0)}
99+
|Candidate tokens: : ${usage.candidatesTokenCount.getOrElse(0)}
100+
|Total tokens : ${usage.totalTokenCount}""".stripMargin
101+
)
102+
}
103+
104+
cachedContentNameNew <- service.getCachedContent(cachedContentName)
105+
106+
_ = logger.info(
107+
s"${cachedContentNameNew.name.get} - expire time : " + cachedContentNameNew.expireTime
108+
)
109+
110+
_ <- service.deleteCachedContent(cachedContentName)
111+
112+
_ <- listCachedContents
113+
} yield ()
114+
}
115+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package io.cequence.openaiscala.examples.nonopenai
2+
3+
import com.typesafe.scalalogging.Logger
4+
import io.cequence.openaiscala.domain.NonOpenAIModelId
5+
import io.cequence.openaiscala.examples.ExampleBase
6+
import io.cequence.openaiscala.gemini.domain.ChatRole.User
7+
import io.cequence.openaiscala.gemini.domain.settings.{
8+
GenerateContentSettings,
9+
GenerationConfig
10+
}
11+
import io.cequence.openaiscala.gemini.domain.{CachedContent, Content, Part}
12+
import io.cequence.openaiscala.gemini.service.{GeminiService, GeminiServiceFactory}
13+
import org.slf4j.LoggerFactory
14+
15+
import java.util.Base64
16+
import scala.concurrent.Future
17+
import scala.io.Source
18+
19+
// requires `openai-scala-google-gemini-client` as a dependency and `GOOGLE_API_KEY` environment variable to be set
20+
object GoogleGeminiGenerateContentCachedWithInlineData extends ExampleBase[GeminiService] {
21+
22+
override protected val service: GeminiService = GeminiServiceFactory()
23+
24+
protected val logger: Logger = Logger(LoggerFactory.getLogger(this.getClass))
25+
26+
private val systemPrompt = "You are a helpful assistant and expert in Norway."
27+
private val userPrompt = "Write the section 'Higher education in Norway' verbatim."
28+
private val knowledgeFile = getClass.getResource("/norway_wiki.md").getFile
29+
30+
private lazy val knowledgeContent = {
31+
val source = Source.fromFile(knowledgeFile)
32+
try source.mkString("")
33+
finally source.close()
34+
}
35+
36+
private val model = NonOpenAIModelId.gemini_1_5_flash_002
37+
38+
private val knowledgeInlineData: Content =
39+
Content(
40+
User,
41+
Part.InlineData(
42+
mimeType = "text/plain",
43+
data = Base64.getEncoder.encodeToString(knowledgeContent.getBytes("UTF-8"))
44+
)
45+
)
46+
47+
override protected def run: Future[_] =
48+
for {
49+
// create cached content
50+
saveCachedContent <- service.createCachedContent(
51+
CachedContent(
52+
contents = Seq(knowledgeInlineData),
53+
systemInstruction = Some(Content.textPart(systemPrompt, User)),
54+
model = model
55+
)
56+
)
57+
58+
cachedContentName = saveCachedContent.name.get
59+
60+
_ = logger.info(s"${cachedContentName} - expire time : " + saveCachedContent.expireTime)
61+
62+
// chat completion with cached content
63+
response <- service.generateContent(
64+
Seq(Content.textPart(userPrompt, User)),
65+
settings = GenerateContentSettings(
66+
model = model,
67+
generationConfig = Some(
68+
GenerationConfig(
69+
maxOutputTokens = Some(2000),
70+
temperature = Some(0.2)
71+
)
72+
),
73+
cachedContent = Some(cachedContentName)
74+
)
75+
)
76+
77+
// response
78+
_ = logger.info("Response: " + response.contentHeadText)
79+
80+
// clean up
81+
_ <- service.deleteCachedContent(cachedContentName)
82+
} yield ()
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package io.cequence.openaiscala.examples.nonopenai
2+
3+
import com.typesafe.scalalogging.Logger
4+
import io.cequence.openaiscala.domain.NonOpenAIModelId
5+
import io.cequence.openaiscala.examples.ExampleBase
6+
import io.cequence.openaiscala.gemini.domain.ChatRole.User
7+
import io.cequence.openaiscala.gemini.domain.{CachedContent, Content}
8+
import io.cequence.openaiscala.gemini.service.{GeminiService, GeminiServiceFactory}
9+
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService
10+
import org.slf4j.LoggerFactory
11+
12+
import scala.concurrent.Future
13+
import scala.io.Source
14+
15+
// requires `openai-scala-google-gemini-client` as a dependency and `GOOGLE_API_KEY` environment variable to be set
16+
object GoogleGeminiGenerateContentCachedWithOpenAIAdapter
17+
extends ExampleBase[OpenAIChatCompletionStreamedService] {
18+
19+
override val service: OpenAIChatCompletionStreamedService = GeminiServiceFactory.asOpenAI()
20+
21+
private val rawGeminiService: GeminiService = GeminiServiceFactory()
22+
23+
protected val logger: Logger = Logger(LoggerFactory.getLogger(this.getClass))
24+
25+
private val systemPrompt = "You are a helpful assistant and expert in Norway."
26+
private val userPrompt = "Write the section 'Higher education in Norway' verbatim."
27+
private val knowledgeFile = getClass.getResource("/norway_wiki.md").getFile
28+
29+
private lazy val knowledgeContent = {
30+
val source = Source.fromFile(knowledgeFile)
31+
try source.mkString("")
32+
finally source.close()
33+
}
34+
35+
private val model = NonOpenAIModelId.gemini_1_5_flash_002
36+
37+
private val knowledgeTextContent: Content =
38+
Content.textPart(
39+
knowledgeContent,
40+
User
41+
)
42+
43+
// TODO
44+
override protected def run: Future[_] =
45+
for {
46+
saveCachedContent <- rawGeminiService.createCachedContent(
47+
CachedContent(
48+
contents = Seq(knowledgeTextContent),
49+
systemInstruction = Some(Content.textPart(systemPrompt, User)),
50+
model = model
51+
)
52+
)
53+
//
54+
// response <- service.createChatCompletion(
55+
//
56+
// )
57+
58+
_ <- rawGeminiService.deleteCachedContent(saveCachedContent.name.get)
59+
} yield ()
60+
}

0 commit comments

Comments (0)