Skip to content

Commit d8469ab

Browse files
authored
Merge 38806af into 5fe51eb
2 parents 5fe51eb + 38806af commit d8469ab

File tree

65 files changed

+3609
-123
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+3609
-123
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ buildscript {
4141
classpath 'com.google.firebase:firebase-appdistribution-gradle:5.0.0'
4242
classpath 'com.google.firebase:firebase-crashlytics-gradle:2.9.5'
4343
classpath "com.diffplug.spotless:spotless-plugin-gradle:7.0.0.BETA1"
44+
classpath "org.jetbrains.kotlin:kotlin-serialization:1.8.22"
4445
}
4546
}
4647

firebase-vertexai/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Unreleased
2+
* [changed] Merged core networking code into VertexAI from a separate library
23
* [feature] added support for `responseSchema` in `GenerationConfig`.
34

45
# 16.0.0-beta03

firebase-vertexai/firebase-vertexai.gradle.kts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
plugins {
2020
id("firebase-library")
2121
id("kotlin-android")
22+
kotlin("plugin.serialization")
2223
}
2324

2425
firebaseLibrary {
@@ -56,12 +57,19 @@ android {
5657
}
5758

5859
dependencies {
59-
api("com.google.firebase:firebase-common:21.0.0")
60+
val ktorVersion = "2.3.2"
61+
62+
implementation("io.ktor:ktor-client-okhttp:$ktorVersion")
63+
implementation("io.ktor:ktor-client-core:$ktorVersion")
64+
implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion")
65+
implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion")
66+
implementation("io.ktor:ktor-client-logging:$ktorVersion")
67+
compileOnly("io.ktor:ktor-client-mock:$ktorVersion")
6068

69+
implementation("com.google.firebase:firebase-common:21.0.0")
6170
implementation("com.google.firebase:firebase-components:18.0.0")
6271
implementation("com.google.firebase:firebase-annotations:16.2.0")
6372
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
64-
implementation("com.google.ai.client.generativeai:common:0.10.0")
6573
implementation(libs.androidx.annotation)
6674
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
6775
implementation("androidx.core:core-ktx:1.12.0")
@@ -74,9 +82,9 @@ dependencies {
7482
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03")
7583
implementation("com.google.firebase:firebase-auth-interop:18.0.0")
7684

77-
val ktorVersion = "2.3.2"
7885
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
7986
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
87+
testImplementation("io.kotest:kotest-assertions-json:5.5.5")
8088
testImplementation("io.ktor:ktor-client-okhttp:$ktorVersion")
8189
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
8290
testImplementation("org.json:json:20240303")

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ package com.google.firebase.vertexai
1818

1919
import android.graphics.Bitmap
2020
import android.util.Log
21-
import com.google.ai.client.generativeai.common.APIController
22-
import com.google.ai.client.generativeai.common.CountTokensRequest
23-
import com.google.ai.client.generativeai.common.GenerateContentRequest
24-
import com.google.ai.client.generativeai.common.HeaderProvider
2521
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2622
import com.google.firebase.auth.internal.InternalAuthProvider
23+
import com.google.firebase.vertexai.common.APIController
24+
import com.google.firebase.vertexai.common.CountTokensRequest
25+
import com.google.firebase.vertexai.common.GenerateContentRequest
26+
import com.google.firebase.vertexai.common.HeaderProvider
2727
import com.google.firebase.vertexai.internal.util.toInternal
2828
import com.google.firebase.vertexai.internal.util.toPublic
2929
import com.google.firebase.vertexai.type.Content
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai.common
18+
19+
import android.util.Log
20+
import androidx.annotation.VisibleForTesting
21+
import com.google.firebase.vertexai.common.server.FinishReason
22+
import com.google.firebase.vertexai.common.util.decodeToFlow
23+
import com.google.firebase.vertexai.common.util.fullModelName
24+
import io.ktor.client.HttpClient
25+
import io.ktor.client.call.body
26+
import io.ktor.client.engine.HttpClientEngine
27+
import io.ktor.client.engine.mock.MockEngine
28+
import io.ktor.client.engine.mock.respond
29+
import io.ktor.client.engine.okhttp.OkHttp
30+
import io.ktor.client.plugins.HttpTimeout
31+
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
32+
import io.ktor.client.request.HttpRequestBuilder
33+
import io.ktor.client.request.header
34+
import io.ktor.client.request.post
35+
import io.ktor.client.request.preparePost
36+
import io.ktor.client.request.setBody
37+
import io.ktor.client.statement.HttpResponse
38+
import io.ktor.client.statement.bodyAsChannel
39+
import io.ktor.client.statement.bodyAsText
40+
import io.ktor.http.ContentType
41+
import io.ktor.http.HttpHeaders
42+
import io.ktor.http.HttpStatusCode
43+
import io.ktor.http.contentType
44+
import io.ktor.http.headersOf
45+
import io.ktor.serialization.kotlinx.json.json
46+
import io.ktor.utils.io.ByteChannel
47+
import kotlin.time.Duration
48+
import kotlinx.coroutines.CoroutineName
49+
import kotlinx.coroutines.TimeoutCancellationException
50+
import kotlinx.coroutines.flow.Flow
51+
import kotlinx.coroutines.flow.catch
52+
import kotlinx.coroutines.flow.channelFlow
53+
import kotlinx.coroutines.flow.map
54+
import kotlinx.coroutines.launch
55+
import kotlinx.coroutines.withTimeout
56+
import kotlinx.serialization.json.Json
57+
58+
internal val JSON = Json {
59+
ignoreUnknownKeys = true
60+
prettyPrint = false
61+
isLenient = true
62+
}
63+
64+
/**
65+
* Backend class for interfacing with the Gemini API.
66+
*
67+
* This class handles making HTTP requests to the API and streaming the responses back.
68+
*
69+
* @param httpEngine The HTTP client engine to be used for making requests. Defaults to CIO engine.
70+
* Exposed primarily for DI in tests.
71+
* @property key The API key used for authentication.
72+
* @property model The model to use for generation.
73+
* @property apiClient The value to pass in the `x-goog-api-client` header.
74+
* @property headerProvider A provider that generates extra headers to include in all HTTP requests.
75+
*/
76+
internal class APIController
77+
internal constructor(
78+
private val key: String,
79+
model: String,
80+
private val requestOptions: RequestOptions,
81+
httpEngine: HttpClientEngine,
82+
private val apiClient: String,
83+
private val headerProvider: HeaderProvider?,
84+
) {
85+
86+
constructor(
87+
key: String,
88+
model: String,
89+
requestOptions: RequestOptions,
90+
apiClient: String,
91+
headerProvider: HeaderProvider? = null,
92+
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)
93+
94+
@VisibleForTesting(otherwise = VisibleForTesting.NONE)
95+
constructor(
96+
key: String,
97+
model: String,
98+
requestOptions: RequestOptions,
99+
apiClient: String,
100+
headerProvider: HeaderProvider?,
101+
channel: ByteChannel,
102+
status: HttpStatusCode,
103+
) : this(
104+
key,
105+
model,
106+
requestOptions,
107+
MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) },
108+
apiClient,
109+
headerProvider,
110+
)
111+
112+
private val model = fullModelName(model)
113+
114+
private val client =
115+
HttpClient(httpEngine) {
116+
install(HttpTimeout) {
117+
requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds
118+
socketTimeoutMillis = 80_000
119+
}
120+
install(ContentNegotiation) { json(JSON) }
121+
}
122+
123+
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
124+
try {
125+
client
126+
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
127+
applyCommonConfiguration(request)
128+
applyHeaderProvider()
129+
}
130+
.also { validateResponse(it) }
131+
.body<GenerateContentResponse>()
132+
.validate()
133+
} catch (e: Throwable) {
134+
throw FirebaseCommonAIException.from(e)
135+
}
136+
137+
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
138+
client
139+
.postStream<GenerateContentResponse>(
140+
"${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
141+
) {
142+
applyCommonConfiguration(request)
143+
}
144+
.map { it.validate() }
145+
.catch { throw FirebaseCommonAIException.from(it) }
146+
147+
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
148+
try {
149+
client
150+
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
151+
applyCommonConfiguration(request)
152+
applyHeaderProvider()
153+
}
154+
.also { validateResponse(it) }
155+
.body()
156+
} catch (e: Throwable) {
157+
throw FirebaseCommonAIException.from(e)
158+
}
159+
160+
private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
161+
when (request) {
162+
is GenerateContentRequest -> setBody<GenerateContentRequest>(request)
163+
is CountTokensRequest -> setBody<CountTokensRequest>(request)
164+
}
165+
contentType(ContentType.Application.Json)
166+
header("x-goog-api-key", key)
167+
header("x-goog-api-client", apiClient)
168+
}
169+
170+
private suspend fun HttpRequestBuilder.applyHeaderProvider() {
171+
if (headerProvider != null) {
172+
try {
173+
withTimeout(headerProvider.timeout) {
174+
for ((tag, value) in headerProvider.generateHeaders()) {
175+
header(tag, value)
176+
}
177+
}
178+
} catch (e: TimeoutCancellationException) {
179+
Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring")
180+
}
181+
}
182+
}
183+
184+
/**
185+
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response
186+
* objects of type [R]. The response is expected to be a stream of JSON objects that are parsed in
187+
* real-time as they are received from the server.
188+
*
189+
* This function is intended for internal use within the client that handles streaming responses.
190+
*
191+
* Example usage:
192+
* ```
193+
* val client: HttpClient = HttpClient(CIO)
194+
* val request: Request = GenerateContentRequest(...)
195+
* val url: String = "http://example.com/stream"
196+
*
197+
* val responses: GenerateContentResponse = client.postStream(url) {
198+
* setBody(request)
199+
* contentType(ContentType.Application.Json)
200+
* }
201+
* responses.collect {
202+
* println("Got a response: $it")
203+
* }
204+
* ```
205+
*
206+
* @param R The type of the response object.
207+
* @param url The URL to which the POST request will be made.
208+
* @param config An optional [HttpRequestBuilder] callback for request configuration.
209+
* @return A [Flow] of response objects of type [R].
210+
*/
211+
private inline fun <reified R : Response> HttpClient.postStream(
212+
url: String,
213+
crossinline config: HttpRequestBuilder.() -> Unit = {},
214+
): Flow<R> = channelFlow {
215+
launch(CoroutineName("postStream")) {
216+
preparePost(url) {
217+
applyHeaderProvider()
218+
config()
219+
}
220+
.execute {
221+
validateResponse(it)
222+
223+
val channel = it.bodyAsChannel()
224+
val flow = JSON.decodeToFlow<R>(channel)
225+
226+
flow.collect { send(it) }
227+
}
228+
}
229+
}
230+
231+
companion object {
232+
private val TAG = APIController::class.java.simpleName
233+
}
234+
}
235+
236+
internal interface HeaderProvider {
237+
val timeout: Duration
238+
239+
suspend fun generateHeaders(): Map<String, String>
240+
}
241+
242+
private suspend fun validateResponse(response: HttpResponse) {
243+
if (response.status == HttpStatusCode.OK) return
244+
val text = response.bodyAsText()
245+
val error =
246+
try {
247+
JSON.decodeFromString<GRpcErrorResponse>(text).error
248+
} catch (e: Throwable) {
249+
throw ServerException("Unexpected Response:\n$text $e")
250+
}
251+
val message = error.message
252+
if (message.contains("API key not valid")) {
253+
throw InvalidAPIKeyException(message)
254+
}
255+
// TODO (b/325117891): Use a better method than string matching.
256+
if (message == "User location is not supported for the API use.") {
257+
throw UnsupportedUserLocationException()
258+
}
259+
if (message.contains("quota")) {
260+
throw QuotaExceededException(message)
261+
}
262+
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
263+
throw ServiceDisabledException(message)
264+
}
265+
throw ServerException(message)
266+
}
267+
268+
private fun GenerateContentResponse.validate() = apply {
269+
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
270+
throw SerializationException("Error deserializing response, found no valid fields")
271+
}
272+
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
273+
candidates
274+
?.mapNotNull { it.finishReason }
275+
?.firstOrNull { it != FinishReason.STOP }
276+
?.let { throw ResponseStoppedException(this) }
277+
}

0 commit comments

Comments
 (0)