Skip to content

Use kotlin's explicit API in vertexAI #6313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions firebase-vertexai/firebase-vertexai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

@file:Suppress("UnstableApiUsage")

import org.jetbrains.kotlin.gradle.tasks.KotlinCompile


plugins {
id("firebase-library")
id("kotlin-android")
Expand Down Expand Up @@ -66,6 +69,21 @@ android {
}
}

// Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
// classes, methods, or properties have implicit `public` visibility. This check helps
// avoid accidentally leaking elements into the public API, requiring that any public
// element be explicitly declared as `public`.
// https://github.com/Kotlin/KEEP/blob/master/proposals/explicit-api-mode.md
// https://chao2zhang.medium.com/explicit-api-mode-for-kotlin-on-android-b8264fdd76d1
tasks.withType<KotlinCompile>().all {
if (!name.contains("test", ignoreCase = true)) {
if (!kotlinOptions.freeCompilerArgs.contains("-Xexplicit-api=strict")) {
kotlinOptions.freeCompilerArgs += "-Xexplicit-api=strict"
}
}
}


dependencies {
val ktorVersion = "2.3.2"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ import kotlinx.coroutines.flow.onEach
* @param model The model to use for the interaction
* @property history The previous interactions with the model
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
public class Chat(
private val model: GenerativeModel,
public val history: MutableList<Content> = ArrayList()
) {
private var lock = Semaphore(1)

/**
Expand All @@ -53,7 +56,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
public suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
attemptLock()
try {
Expand All @@ -72,7 +75,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: String): GenerateContentResponse {
public suspend fun sendMessage(prompt: String): GenerateContentResponse {
val content = content { text(prompt) }
return sendMessage(content)
}
Expand All @@ -83,7 +86,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
public suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
val content = content { image(prompt) }
return sendMessage(content)
}
Expand All @@ -96,7 +99,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
public fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
attemptLock()

Expand Down Expand Up @@ -149,7 +152,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
public fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
val content = content { text(prompt) }
return sendMessageStream(content)
}
Expand All @@ -161,7 +164,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
public fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
val content = content { image(prompt) }
return sendMessageStream(content)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig

/** Entry point for all _Vertex AI for Firebase_ functionality. */
class FirebaseVertexAI
public class FirebaseVertexAI
internal constructor(
private val firebaseApp: FirebaseApp,
private val location: String,
Expand All @@ -51,7 +51,7 @@ internal constructor(
* @param systemInstruction contains a [Content] that directs the model to behave a certain way
*/
@JvmOverloads
fun generativeModel(
public fun generativeModel(
modelName: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
Expand All @@ -77,13 +77,13 @@ internal constructor(
)
}

companion object {
public companion object {
/** The [FirebaseVertexAI] instance for the default [FirebaseApp] */
@JvmStatic
val instance: FirebaseVertexAI
public val instance: FirebaseVertexAI
get() = getInstance(location = "us-central1")

@JvmStatic fun getInstance(app: FirebaseApp): FirebaseVertexAI = getInstance(app)
@JvmStatic public fun getInstance(app: FirebaseApp): FirebaseVertexAI = getInstance(app)

/**
* Returns the [FirebaseVertexAI] instance for the provided [FirebaseApp] and [location]
Expand All @@ -93,19 +93,19 @@ internal constructor(
*/
@JvmStatic
@JvmOverloads
fun getInstance(app: FirebaseApp = Firebase.app, location: String): FirebaseVertexAI {
public fun getInstance(app: FirebaseApp = Firebase.app, location: String): FirebaseVertexAI {
val multiResourceComponent = app[FirebaseVertexAIMultiResourceComponent::class.java]
return multiResourceComponent.get(location)
}
}
}

/** Returns the [FirebaseVertexAI] instance of the default [FirebaseApp]. */
val Firebase.vertexAI: FirebaseVertexAI
public val Firebase.vertexAI: FirebaseVertexAI
get() = FirebaseVertexAI.instance

/** Returns the [FirebaseVertexAI] instance of a given [FirebaseApp]. */
fun Firebase.vertexAI(
public fun Firebase.vertexAI(
app: FirebaseApp = Firebase.app,
location: String = "us-central1"
): FirebaseVertexAI = FirebaseVertexAI.getInstance(app, location)
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import kotlinx.coroutines.tasks.await
/**
* A controller for communicating with the API of a given multimodal model (for example, Gemini).
*/
class GenerativeModel
public class GenerativeModel
internal constructor(
private val modelName: String,
private val generationConfig: GenerationConfig? = null,
Expand Down Expand Up @@ -128,7 +128,7 @@ internal constructor(
* @return A [GenerateContentResponse]. Function should be called within a suspend context to
* properly manage concurrency.
*/
suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
public suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
try {
controller.generateContent(constructRequest(*prompt)).toPublic().validate()
} catch (e: Throwable) {
Expand All @@ -141,7 +141,7 @@ internal constructor(
* @param prompt [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
public fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
controller
.generateContentStream(constructRequest(*prompt))
.catch { throw FirebaseVertexAIException.from(it) }
Expand All @@ -154,7 +154,7 @@ internal constructor(
* @return A [GenerateContentResponse] after some delay. Function should be called within a
* suspend context to properly manage concurrency.
*/
suspend fun generateContent(prompt: String): GenerateContentResponse =
public suspend fun generateContent(prompt: String): GenerateContentResponse =
generateContent(content { text(prompt) })

/**
Expand All @@ -163,7 +163,7 @@ internal constructor(
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
public fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
generateContentStream(content { text(prompt) })

/**
Expand All @@ -173,7 +173,7 @@ internal constructor(
* @return A [GenerateContentResponse] after some delay. Function should be called within a
* suspend context to properly manage concurrency.
*/
suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
public suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
generateContent(content { image(prompt) })

/**
Expand All @@ -182,19 +182,20 @@ internal constructor(
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
public fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
generateContentStream(content { image(prompt) })

/** Creates a [Chat] instance which internally tracks the ongoing conversation with the model */
fun startChat(history: List<Content> = emptyList()): Chat = Chat(this, history.toMutableList())
public fun startChat(history: List<Content> = emptyList()): Chat =
Chat(this, history.toMutableList())

/**
* Counts the amount of tokens in a prompt.
*
* @param prompt A group of [Content] to count tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
public suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
try {
return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic()
} catch (e: Throwable) {
Expand All @@ -208,7 +209,7 @@ internal constructor(
* @param prompt The text to be converted to a single piece of [Content] to count the tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
suspend fun countTokens(prompt: String): CountTokensResponse {
public suspend fun countTokens(prompt: String): CountTokensResponse {
return countTokens(content { text(prompt) })
}

Expand All @@ -218,7 +219,7 @@ internal constructor(
* @param prompt The image to be converted to a single piece of [Content] to count the tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
public suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
return countTokens(content { image(prompt) })
}

Expand Down Expand Up @@ -247,7 +248,7 @@ internal constructor(
?.let { throw ResponseStoppedException(this) }
}

companion object {
private companion object {
private val TAG = GenerativeModel::class.java.simpleName
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ internal enum class HarmCategory {
@SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT
}

typealias Base64 = String
internal typealias Base64 = String

@ExperimentalSerializationApi
@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,25 @@ import org.reactivestreams.Publisher
*
* @see from
*/
abstract class ChatFutures internal constructor() {
public abstract class ChatFutures internal constructor() {

/**
* Generates a response from the backend with the provided [Content], and any previous ones
* sent/returned from this chat.
*
* @param prompt A [Content] to send to the model.
*/
abstract fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse>
public abstract fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse>

/**
* Generates a streaming response from the backend with the provided [Content].
*
* @param prompt A [Content] to send to the model.
*/
abstract fun sendMessageStream(prompt: Content): Publisher<GenerateContentResponse>
public abstract fun sendMessageStream(prompt: Content): Publisher<GenerateContentResponse>

/** Returns the [Chat] instance that was used to create this instance */
abstract fun getChat(): Chat
public abstract fun getChat(): Chat

private class FuturesImpl(private val chat: Chat) : ChatFutures() {
override fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse> =
Expand All @@ -59,9 +59,9 @@ abstract class ChatFutures internal constructor() {
override fun getChat(): Chat = chat
}

companion object {
public companion object {

/** @return a [ChatFutures] created around the provided [Chat] */
@JvmStatic fun from(chat: Chat): ChatFutures = FuturesImpl(chat)
@JvmStatic public fun from(chat: Chat): ChatFutures = FuturesImpl(chat)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,41 +31,45 @@ import org.reactivestreams.Publisher
*
* @see from
*/
abstract class GenerativeModelFutures internal constructor() {
public abstract class GenerativeModelFutures internal constructor() {

/**
* Generates a response from the backend with the provided [Content].
*
* @param prompt A group of [Content] to send to the model.
*/
abstract fun generateContent(vararg prompt: Content): ListenableFuture<GenerateContentResponse>
public abstract fun generateContent(
vararg prompt: Content
): ListenableFuture<GenerateContentResponse>

/**
* Generates a streaming response from the backend with the provided [Content].
*
* @param prompt A group of [Content] to send to the model.
*/
abstract fun generateContentStream(vararg prompt: Content): Publisher<GenerateContentResponse>
public abstract fun generateContentStream(
vararg prompt: Content
): Publisher<GenerateContentResponse>

/**
* Counts the number of tokens used in a prompt.
*
* @param prompt A group of [Content] to count tokens of.
*/
abstract fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse>
public abstract fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse>

/** Creates a chat instance which internally tracks the ongoing conversation with the model */
abstract fun startChat(): ChatFutures
public abstract fun startChat(): ChatFutures

/**
* Creates a chat instance which internally tracks the ongoing conversation with the model
*
* @param history an existing history of context to use as a starting point
*/
abstract fun startChat(history: List<Content>): ChatFutures
public abstract fun startChat(history: List<Content>): ChatFutures

/** Returns the [GenerativeModel] instance that was used to create this object */
abstract fun getGenerativeModel(): GenerativeModel
public abstract fun getGenerativeModel(): GenerativeModel

private class FuturesImpl(private val model: GenerativeModel) : GenerativeModelFutures() {
override fun generateContent(
Expand All @@ -86,9 +90,9 @@ abstract class GenerativeModelFutures internal constructor() {
override fun getGenerativeModel(): GenerativeModel = model
}

companion object {
public companion object {

/** @return a [GenerativeModelFutures] created around the provided [GenerativeModel] */
@JvmStatic fun from(model: GenerativeModel): GenerativeModelFutures = FuturesImpl(model)
@JvmStatic public fun from(model: GenerativeModel): GenerativeModelFutures = FuturesImpl(model)
}
}
Loading
Loading