Skip to content

Commit 7bab838

Browse files
authored
Introduce FunctionCall and FunctionResponse types (#6311)
Their *part counter parts will now wrap them, instead of exposing the underlying structure directly.
1 parent ffcc8ba commit 7bab838

File tree

3 files changed

+70
-34
lines changed

3 files changed

+70
-34
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ import android.graphics.BitmapFactory
2121
import android.util.Base64
2222
import com.google.firebase.vertexai.common.client.Schema
2323
import com.google.firebase.vertexai.common.shared.FileData
24-
import com.google.firebase.vertexai.common.shared.FunctionCall
25-
import com.google.firebase.vertexai.common.shared.FunctionCallPart
26-
import com.google.firebase.vertexai.common.shared.FunctionResponse
27-
import com.google.firebase.vertexai.common.shared.FunctionResponsePart
2824
import com.google.firebase.vertexai.common.shared.InlineData
2925
import com.google.firebase.vertexai.type.BlockReason
3026
import com.google.firebase.vertexai.type.Candidate
@@ -34,8 +30,12 @@ import com.google.firebase.vertexai.type.Content
3430
import com.google.firebase.vertexai.type.CountTokensResponse
3531
import com.google.firebase.vertexai.type.FileDataPart
3632
import com.google.firebase.vertexai.type.FinishReason
33+
import com.google.firebase.vertexai.type.FunctionCall
34+
import com.google.firebase.vertexai.type.FunctionCallPart
3735
import com.google.firebase.vertexai.type.FunctionCallingConfig
3836
import com.google.firebase.vertexai.type.FunctionDeclaration
37+
import com.google.firebase.vertexai.type.FunctionResponse
38+
import com.google.firebase.vertexai.type.FunctionResponsePart
3939
import com.google.firebase.vertexai.type.GenerateContentResponse
4040
import com.google.firebase.vertexai.type.GenerationConfig
4141
import com.google.firebase.vertexai.type.HarmBlockMethod
@@ -59,6 +59,7 @@ import java.io.ByteArrayOutputStream
5959
import java.util.Calendar
6060
import kotlinx.serialization.json.Json
6161
import kotlinx.serialization.json.JsonObject
62+
import kotlinx.serialization.json.JsonPrimitive
6263
import org.json.JSONObject
6364

6465
private const val BASE_64_FLAGS = Base64.NO_WRAP
@@ -80,10 +81,10 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
8081
com.google.firebase.vertexai.common.shared.InlineDataPart(
8182
InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS))
8283
)
83-
is com.google.firebase.vertexai.type.FunctionCallPart ->
84-
FunctionCallPart(FunctionCall(name, args.orEmpty()))
85-
is com.google.firebase.vertexai.type.FunctionResponsePart ->
86-
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
84+
is FunctionCallPart ->
85+
com.google.firebase.vertexai.common.shared.FunctionCallPart(functionCall.toInternal())
86+
is FunctionResponsePart ->
87+
com.google.firebase.vertexai.common.shared.FunctionResponsePart(functionResponse.toInternal())
8788
is FileDataPart ->
8889
com.google.firebase.vertexai.common.shared.FileDataPart(
8990
FileData(mimeType = mimeType, fileUri = uri)
@@ -95,6 +96,15 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
9596
}
9697
}
9798

99+
internal fun FunctionCall.toInternal() =
100+
com.google.firebase.vertexai.common.shared.FunctionCall(
101+
name,
102+
args.orEmpty().mapValues { it.value.toString() }
103+
)
104+
105+
internal fun FunctionResponse.toInternal() =
106+
com.google.firebase.vertexai.common.shared.FunctionResponse(name, response)
107+
98108
internal fun SafetySetting.toInternal() =
99109
com.google.firebase.vertexai.common.shared.SafetySetting(
100110
harmCategory.toInternal(),
@@ -213,16 +223,10 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
213223
InlineDataPart(inlineData.mimeType, data)
214224
}
215225
}
216-
is FunctionCallPart ->
217-
com.google.firebase.vertexai.type.FunctionCallPart(
218-
functionCall.name,
219-
functionCall.args.orEmpty(),
220-
)
221-
is FunctionResponsePart ->
222-
com.google.firebase.vertexai.type.FunctionResponsePart(
223-
functionResponse.name,
224-
functionResponse.response.toPublic(),
225-
)
226+
is com.google.firebase.vertexai.common.shared.FunctionCallPart ->
227+
FunctionCallPart(functionCall.toPublic())
228+
is com.google.firebase.vertexai.common.shared.FunctionResponsePart ->
229+
FunctionResponsePart(functionResponse.toPublic())
226230
is com.google.firebase.vertexai.common.shared.FileDataPart ->
227231
FileDataPart(fileData.mimeType, fileData.fileUri)
228232
else ->
@@ -232,6 +236,21 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
232236
}
233237
}
234238

239+
internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() =
240+
FunctionCall(
241+
name,
242+
args.orEmpty().mapValues {
243+
val argValue = it.value
244+
if (argValue == null) JsonPrimitive(null) else Json.parseToJsonElement(argValue)
245+
}
246+
)
247+
248+
internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() =
249+
FunctionResponse(
250+
name,
251+
response,
252+
)
253+
235254
internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation {
236255
val publicationDateAsCalendar =
237256
publicationDate?.let {

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
package com.google.firebase.vertexai.type
1818

1919
import android.graphics.Bitmap
20-
import org.json.JSONObject
20+
import kotlinx.serialization.json.JsonElement
21+
import kotlinx.serialization.json.JsonObject
2122

2223
/** Interface representing data sent to and received from requests. */
2324
interface Part
@@ -44,20 +45,35 @@ class ImagePart(val image: Bitmap) : Part
4445
class InlineDataPart(val mimeType: String, val inlineData: ByteArray) : Part
4546

4647
/**
47-
* Represents function call name and params received from requests.
48+
* Represents a function call request from the model
49+
*
50+
* @param functionCall The information provided by the model to call a function.
51+
*/
52+
class FunctionCallPart(val functionCall: FunctionCall) : Part
53+
54+
/**
55+
* The result of calling a function as requested by the model.
56+
*
57+
* @param functionResponse The information to send back to the model as the result of a functions
58+
* call.
59+
*/
60+
class FunctionResponsePart(val functionResponse: FunctionResponse) : Part
61+
62+
/**
63+
* The data necessary to invoke function [name] using the arguments [args].
4864
*
4965
* @param name the name of the function to call
5066
* @param args the function parameters and values as a [Map]
5167
*/
52-
class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
68+
class FunctionCall(val name: String, val args: Map<String, JsonElement>)
5369

5470
/**
55-
* Represents function call output to be returned to the model when it requests a function call.
71+
* The [response] generated after calling function [name].
5672
*
5773
* @param name the name of the called function
58-
* @param response the response produced by the function as a [JSONObject]
74+
* @param response the response produced by the function as a [JsonObject]
5975
*/
60-
class FunctionResponsePart(val name: String, val response: JSONObject) : Part
76+
class FunctionResponse(val name: String, val response: JsonObject)
6177

6278
/**
6379
* Represents file data stored in Cloud Storage for Firebase, referenced by URI.

firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf
4444
import io.ktor.http.HttpStatusCode
4545
import kotlin.time.Duration.Companion.seconds
4646
import kotlinx.coroutines.withTimeout
47+
import kotlinx.serialization.json.JsonPrimitive
4748
import org.json.JSONArray
4849
import org.junit.Test
4950

@@ -350,7 +351,7 @@ internal class UnarySnapshotTests {
350351
val response = model.generateContent("prompt")
351352
val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart)
352353

353-
callPart.args["season"] shouldBe null
354+
callPart.functionCall.args["season"] shouldBe JsonPrimitive(null)
354355
}
355356
}
356357

@@ -367,7 +368,7 @@ internal class UnarySnapshotTests {
367368
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
368369
}
369370

370-
callPart.args["current"] shouldBe "true"
371+
callPart.functionCall.args["current"] shouldBe JsonPrimitive(true)
371372
}
372373
}
373374

@@ -378,8 +379,8 @@ internal class UnarySnapshotTests {
378379
val response = model.generateContent("prompt")
379380
val callPart = response.functionCalls.shouldNotBeEmpty().first()
380381

381-
callPart.name shouldBe "current_time"
382-
callPart.args.isEmpty() shouldBe true
382+
callPart.functionCall.name shouldBe "current_time"
383+
callPart.functionCall.args.isEmpty() shouldBe true
383384
}
384385
}
385386

@@ -390,9 +391,9 @@ internal class UnarySnapshotTests {
390391
val response = model.generateContent("prompt")
391392
val callPart = response.functionCalls.shouldNotBeEmpty().first()
392393

393-
callPart.name shouldBe "sum"
394-
callPart.args["x"] shouldBe "4"
395-
callPart.args["y"] shouldBe "5"
394+
callPart.functionCall.name shouldBe "sum"
395+
callPart.functionCall.args["x"] shouldBe JsonPrimitive(4)
396+
callPart.functionCall.args["y"] shouldBe JsonPrimitive(5)
396397
}
397398
}
398399

@@ -405,8 +406,8 @@ internal class UnarySnapshotTests {
405406

406407
callList.size shouldBe 3
407408
callList.forEach {
408-
it.name shouldBe "sum"
409-
it.args.size shouldBe 2
409+
it.functionCall.name shouldBe "sum"
410+
it.functionCall.args.size shouldBe 2
410411
}
411412
}
412413
}
@@ -420,7 +421,7 @@ internal class UnarySnapshotTests {
420421

421422
response.text shouldBe "The sum of [1, 2, 3] is"
422423
callList.size shouldBe 2
423-
callList.forEach { it.args.size shouldBe 2 }
424+
callList.forEach { it.functionCall.args.size shouldBe 2 }
424425
}
425426
}
426427

0 commit comments

Comments
 (0)