Skip to content

Commit ebdf00a

Browse files
committed
Merge common testing
1 parent b2ddfdc commit ebdf00a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2324
-1
lines changed

firebase-vertexai/firebase-vertexai.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ dependencies {
7070
implementation("com.google.firebase:firebase-components:18.0.0")
7171
implementation("com.google.firebase:firebase-annotations:16.2.0")
7272
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
73-
implementation("com.google.ai.client.generativeai:common:0.9.0")
7473
implementation(libs.androidx.annotation)
7574
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
7675
implementation("androidx.core:core-ktx:1.12.0")
@@ -85,6 +84,7 @@ dependencies {
8584

8685
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
8786
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
87+
testImplementation("io.kotest:kotest-assertions-json:5.5.5")
8888
testImplementation("io.ktor:ktor-client-okhttp:$ktorVersion")
8989
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
9090
testImplementation("org.json:json:20240303")
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai.common
18+
19+
import com.google.firebase.vertexai.BuildConfig
20+
import com.google.firebase.vertexai.common.client.FunctionCallingConfig
21+
import com.google.firebase.vertexai.common.client.Tool
22+
import com.google.firebase.vertexai.common.client.ToolConfig
23+
import com.google.firebase.vertexai.common.shared.Content
24+
import com.google.firebase.vertexai.common.shared.TextPart
25+
import com.google.firebase.vertexai.common.util.commonTest
26+
import com.google.firebase.vertexai.common.util.createResponses
27+
import com.google.firebase.vertexai.common.util.doBlocking
28+
import com.google.firebase.vertexai.common.util.prepareStreamingResponse
29+
import io.kotest.assertions.json.shouldContainJsonKey
30+
import io.kotest.assertions.throwables.shouldThrow
31+
import io.kotest.matchers.shouldBe
32+
import io.kotest.matchers.string.shouldContain
33+
import io.ktor.client.engine.mock.MockEngine
34+
import io.ktor.client.engine.mock.respond
35+
import io.ktor.content.TextContent
36+
import io.ktor.http.HttpHeaders
37+
import io.ktor.http.HttpStatusCode
38+
import io.ktor.http.headersOf
39+
import io.ktor.utils.io.ByteChannel
40+
import io.ktor.utils.io.close
41+
import io.ktor.utils.io.writeFully
42+
import kotlin.time.Duration
43+
import kotlin.time.Duration.Companion.milliseconds
44+
import kotlin.time.Duration.Companion.seconds
45+
import kotlinx.coroutines.delay
46+
import kotlinx.coroutines.withTimeout
47+
import kotlinx.serialization.encodeToString
48+
import kotlinx.serialization.json.JsonObject
49+
import org.junit.Test
50+
import org.junit.runner.RunWith
51+
import org.junit.runners.Parameterized
52+
53+
private val TEST_CLIENT_ID = "genai-android/test"
54+
55+
internal class APIControllerTests {
56+
private val testTimeout = 5.seconds
57+
58+
@Test
59+
fun `(generateContentStream) emits responses as they come in`() = commonTest {
60+
val response = createResponses("The", " world", " is", " a", " beautiful", " place!")
61+
val bytes = prepareStreamingResponse(response)
62+
63+
bytes.forEach { channel.writeFully(it) }
64+
val responses = apiController.generateContentStream(textGenerateContentRequest("test"))
65+
66+
withTimeout(testTimeout) {
67+
responses.collect {
68+
it.candidates?.isEmpty() shouldBe false
69+
channel.close()
70+
}
71+
}
72+
}
73+
74+
@Test
75+
fun `(generateContent) respects a custom timeout`() =
76+
commonTest(requestOptions = RequestOptions(2.seconds)) {
77+
shouldThrow<RequestTimeoutException> {
78+
withTimeout(testTimeout) {
79+
apiController.generateContent(textGenerateContentRequest("test"))
80+
}
81+
}
82+
}
83+
}
84+
85+
internal class RequestFormatTests {
86+
@Test
87+
fun `using default endpoint`() = doBlocking {
88+
val channel = ByteChannel(autoFlush = true)
89+
val mockEngine = MockEngine {
90+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
91+
}
92+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
93+
val controller =
94+
APIController(
95+
"super_cool_test_key",
96+
"gemini-pro-1.5",
97+
RequestOptions(),
98+
mockEngine,
99+
"genai-android/${BuildConfig.VERSION_NAME}",
100+
null,
101+
)
102+
103+
withTimeout(5.seconds) {
104+
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
105+
it.candidates?.isEmpty() shouldBe false
106+
channel.close()
107+
}
108+
}
109+
110+
mockEngine.requestHistory.first().url.host shouldBe "generativelanguage.googleapis.com"
111+
}
112+
113+
@Test
114+
fun `using custom endpoint`() = doBlocking {
115+
val channel = ByteChannel(autoFlush = true)
116+
val mockEngine = MockEngine {
117+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
118+
}
119+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
120+
val controller =
121+
APIController(
122+
"super_cool_test_key",
123+
"gemini-pro-1.5",
124+
RequestOptions(endpoint = "https://my.custom.endpoint"),
125+
mockEngine,
126+
TEST_CLIENT_ID,
127+
null,
128+
)
129+
130+
withTimeout(5.seconds) {
131+
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
132+
it.candidates?.isEmpty() shouldBe false
133+
channel.close()
134+
}
135+
}
136+
137+
mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
138+
}
139+
140+
@Test
141+
fun `client id header is set correctly in the request`() = doBlocking {
142+
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
143+
val mockEngine = MockEngine {
144+
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
145+
}
146+
147+
val controller =
148+
APIController(
149+
"super_cool_test_key",
150+
"gemini-pro-1.5",
151+
RequestOptions(),
152+
mockEngine,
153+
TEST_CLIENT_ID,
154+
null,
155+
)
156+
157+
withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
158+
159+
mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
160+
}
161+
162+
@Test
163+
fun `ToolConfig serialization contains correct keys`() = doBlocking {
164+
val channel = ByteChannel(autoFlush = true)
165+
val mockEngine = MockEngine {
166+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
167+
}
168+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
169+
170+
val controller =
171+
APIController(
172+
"super_cool_test_key",
173+
"gemini-pro-1.5",
174+
RequestOptions(),
175+
mockEngine,
176+
TEST_CLIENT_ID,
177+
null,
178+
)
179+
180+
withTimeout(5.seconds) {
181+
controller
182+
.generateContentStream(
183+
GenerateContentRequest(
184+
model = "unused",
185+
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
186+
toolConfig =
187+
ToolConfig(
188+
functionCallingConfig =
189+
FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO)
190+
),
191+
)
192+
)
193+
.collect { channel.close() }
194+
}
195+
196+
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
197+
198+
requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
199+
}
200+
201+
@Test
202+
fun `headers from HeaderProvider are added to the request`() = doBlocking {
203+
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
204+
val mockEngine = MockEngine {
205+
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
206+
}
207+
208+
val testHeaderProvider =
209+
object : HeaderProvider {
210+
override val timeout: Duration
211+
get() = 5.seconds
212+
213+
override suspend fun generateHeaders(): Map<String, String> =
214+
mapOf("header1" to "value1", "header2" to "value2")
215+
}
216+
217+
val controller =
218+
APIController(
219+
"super_cool_test_key",
220+
"gemini-pro-1.5",
221+
RequestOptions(),
222+
mockEngine,
223+
TEST_CLIENT_ID,
224+
testHeaderProvider,
225+
)
226+
227+
withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
228+
229+
mockEngine.requestHistory.first().headers["header1"] shouldBe "value1"
230+
mockEngine.requestHistory.first().headers["header2"] shouldBe "value2"
231+
}
232+
233+
@Test
234+
fun `headers from HeaderProvider are ignored if timeout`() = doBlocking {
235+
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
236+
val mockEngine = MockEngine {
237+
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
238+
}
239+
240+
val testHeaderProvider =
241+
object : HeaderProvider {
242+
override val timeout: Duration
243+
get() = 5.milliseconds
244+
245+
override suspend fun generateHeaders(): Map<String, String> {
246+
delay(10.milliseconds)
247+
return mapOf("header1" to "value1")
248+
}
249+
}
250+
251+
val controller =
252+
APIController(
253+
"super_cool_test_key",
254+
"gemini-pro-1.5",
255+
RequestOptions(),
256+
mockEngine,
257+
TEST_CLIENT_ID,
258+
testHeaderProvider,
259+
)
260+
261+
withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
262+
263+
mockEngine.requestHistory.first().headers.contains("header1") shouldBe false
264+
}
265+
266+
@Test
267+
fun `code execution tool serialization contains correct keys`() = doBlocking {
268+
val channel = ByteChannel(autoFlush = true)
269+
val mockEngine = MockEngine {
270+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
271+
}
272+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
273+
274+
val controller =
275+
APIController(
276+
"super_cool_test_key",
277+
"gemini-pro-1.5",
278+
RequestOptions(),
279+
mockEngine,
280+
TEST_CLIENT_ID,
281+
null,
282+
)
283+
284+
withTimeout(5.seconds) {
285+
controller
286+
.generateContentStream(
287+
GenerateContentRequest(
288+
model = "unused",
289+
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
290+
tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))),
291+
)
292+
)
293+
.collect { channel.close() }
294+
}
295+
296+
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
297+
298+
requestBodyAsText shouldContainJsonKey "tools[0].codeExecution"
299+
}
300+
}
301+
302+
@RunWith(Parameterized::class)
303+
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {
304+
305+
@Test
306+
fun `request should include right model name`() = doBlocking {
307+
val channel = ByteChannel(autoFlush = true)
308+
val mockEngine = MockEngine {
309+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
310+
}
311+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
312+
val controller =
313+
APIController(
314+
"super_cool_test_key",
315+
modelName,
316+
RequestOptions(),
317+
mockEngine,
318+
TEST_CLIENT_ID,
319+
null,
320+
)
321+
322+
withTimeout(5.seconds) {
323+
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
324+
it.candidates?.isEmpty() shouldBe false
325+
channel.close()
326+
}
327+
}
328+
329+
mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
330+
}
331+
332+
companion object {
333+
@JvmStatic
334+
@Parameterized.Parameters
335+
fun data() =
336+
listOf(
337+
arrayOf("gemini-pro", "models/gemini-pro"),
338+
arrayOf("x/gemini-pro", "x/gemini-pro"),
339+
arrayOf("models/gemini-pro", "models/gemini-pro"),
340+
arrayOf("/modelname", "/modelname"),
341+
arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
342+
)
343+
}
344+
}
345+
346+
fun textGenerateContentRequest(prompt: String) =
347+
GenerateContentRequest(
348+
model = "unused",
349+
contents = listOf(Content(parts = listOf(TextPart(prompt)))),
350+
)
351+
352+
fun textCountTokenRequest(prompt: String) =
353+
CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt))

0 commit comments

Comments
 (0)