Skip to content

Commit 92a824b

Browse files
rlazo and daymxn authored
Improvements to vertexAI types (#6309)
A collection of improvements to the VertexAI SDK: - Make properties of `GenerativeModel` private - Rename all `blob.*` to `inlineData.*` - Improvements to `FunctionCallingConfig` - Add support for `frequencyPenalty` and `presencePenalty` - Add support for `HarmBlockMethod` - Add support for `title` and `publicationDate` in `Citation` - Make error handling more robust when details are missing --------- Co-authored-by: Daymon <[email protected]>
1 parent a95da3f commit 92a824b

File tree

14 files changed

+192
-59
lines changed

14 files changed

+192
-59
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/Chat.kt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
package com.google.firebase.vertexai
1818

1919
import android.graphics.Bitmap
20-
import com.google.firebase.vertexai.type.BlobPart
2120
import com.google.firebase.vertexai.type.Content
2221
import com.google.firebase.vertexai.type.GenerateContentResponse
2322
import com.google.firebase.vertexai.type.ImagePart
23+
import com.google.firebase.vertexai.type.InlineDataPart
2424
import com.google.firebase.vertexai.type.InvalidStateException
2525
import com.google.firebase.vertexai.type.TextPart
2626
import com.google.firebase.vertexai.type.content
@@ -102,21 +102,21 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
102102

103103
val flow = model.generateContentStream(*history.toTypedArray(), prompt)
104104
val bitmaps = LinkedList<Bitmap>()
105-
val blobs = LinkedList<BlobPart>()
105+
val inlineDataParts = LinkedList<InlineDataPart>()
106106
val text = StringBuilder()
107107

108108
/**
109-
* TODO: revisit when images and blobs are returned. This will cause issues with how things are
110-
* structured in the response. eg; a text/image/text response will be (incorrectly) represented
111-
* as image/text
109+
* TODO: revisit when images and inline data are returned. This will cause issues with how
110+
* things are structured in the response. eg; a text/image/text response will be (incorrectly)
111+
* represented as image/text
112112
*/
113113
return flow
114114
.onEach {
115115
for (part in it.candidates.first().content.parts) {
116116
when (part) {
117117
is TextPart -> text.append(part.text)
118118
is ImagePart -> bitmaps.add(part.image)
119-
is BlobPart -> blobs.add(part)
119+
is InlineDataPart -> inlineDataParts.add(part)
120120
}
121121
}
122122
}
@@ -128,8 +128,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
128128
for (bitmap in bitmaps) {
129129
image(bitmap)
130130
}
131-
for (blob in blobs) {
132-
blob(blob.mimeType, blob.blob)
131+
for (inlineDataPart in inlineDataParts) {
132+
inlineData(inlineDataPart.mimeType, inlineDataPart.inlineData)
133133
}
134134
if (text.isNotBlank()) {
135135
text(text.toString())

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ import kotlinx.coroutines.tasks.await
5252
*/
5353
class GenerativeModel
5454
internal constructor(
55-
val modelName: String,
56-
val generationConfig: GenerationConfig? = null,
57-
val safetySettings: List<SafetySetting>? = null,
58-
val tools: List<Tool>? = null,
59-
val toolConfig: ToolConfig? = null,
60-
val systemInstruction: Content? = null,
55+
private val modelName: String,
56+
private val generationConfig: GenerationConfig? = null,
57+
private val safetySettings: List<SafetySetting>? = null,
58+
private val tools: List<Tool>? = null,
59+
private val toolConfig: ToolConfig? = null,
60+
private val systemInstruction: Content? = null,
6161
private val controller: APIController
6262
) {
6363

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,25 @@ internal constructor(@JsonNames("citations") val citationSources: List<CitationS
6868

6969
@Serializable
7070
internal data class CitationSources(
71+
val title: String? = null,
7172
val startIndex: Int = 0,
7273
val endIndex: Int,
7374
val uri: String? = null,
7475
val license: String? = null,
76+
val publicationDate: Date? = null,
77+
)
78+
79+
@Serializable
80+
internal data class Date(
81+
/** Year of the date. Must be between 1 and 9999, or 0 for no year. */
82+
val year: Int? = null,
83+
/** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */
84+
val month: Int? = null,
85+
/**
86+
* Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year
87+
* by itself or a year and month where the day isn't significant.
88+
*/
89+
val day: Int? = null,
7590
)
7691

7792
@Serializable
@@ -145,7 +160,7 @@ internal enum class FinishReason {
145160
internal data class GRpcError(
146161
val code: Int,
147162
val message: String,
148-
val details: List<GRpcErrorDetails>
163+
val details: List<GRpcErrorDetails>? = null
149164
)
150165

151166
@Serializable internal data class GRpcErrorDetails(val reason: String? = null)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ internal data class Content(@EncodeDefault val role: String? = "user", val parts
5151

5252
@Serializable internal data class TextPart(val text: String) : Part
5353

54-
@Serializable internal data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part
54+
@Serializable
55+
internal data class InlineDataPart(@SerialName("inline_data") val inlineData: InlineData) : Part
5556

5657
@Serializable internal data class FunctionCallPart(val functionCall: FunctionCall) : Part
5758

@@ -73,7 +74,7 @@ internal data class FileData(
7374
)
7475

7576
@Serializable
76-
internal data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)
77+
internal data class InlineData(@SerialName("mime_type") val mimeType: String, val data: Base64)
7778

7879
@Serializable
7980
internal data class SafetySetting(
@@ -105,7 +106,7 @@ internal object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::cl
105106
"text" in jsonObject -> TextPart.serializer()
106107
"functionCall" in jsonObject -> FunctionCallPart.serializer()
107108
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
108-
"inlineData" in jsonObject -> BlobPart.serializer()
109+
"inlineData" in jsonObject -> InlineDataPart.serializer()
109110
"fileData" in jsonObject -> FileDataPart.serializer()
110111
else -> throw SerializationException("Unknown Part type")
111112
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ import android.graphics.Bitmap
2020
import android.graphics.BitmapFactory
2121
import android.util.Base64
2222
import com.google.firebase.vertexai.common.client.Schema
23-
import com.google.firebase.vertexai.common.shared.Blob
2423
import com.google.firebase.vertexai.common.shared.FileData
2524
import com.google.firebase.vertexai.common.shared.FunctionCall
2625
import com.google.firebase.vertexai.common.shared.FunctionCallPart
2726
import com.google.firebase.vertexai.common.shared.FunctionResponse
2827
import com.google.firebase.vertexai.common.shared.FunctionResponsePart
29-
import com.google.firebase.vertexai.type.BlobPart
28+
import com.google.firebase.vertexai.common.shared.InlineData
3029
import com.google.firebase.vertexai.type.BlockReason
3130
import com.google.firebase.vertexai.type.Candidate
3231
import com.google.firebase.vertexai.type.Citation
@@ -39,11 +38,13 @@ import com.google.firebase.vertexai.type.FunctionCallingConfig
3938
import com.google.firebase.vertexai.type.FunctionDeclaration
4039
import com.google.firebase.vertexai.type.GenerateContentResponse
4140
import com.google.firebase.vertexai.type.GenerationConfig
41+
import com.google.firebase.vertexai.type.HarmBlockMethod
4242
import com.google.firebase.vertexai.type.HarmBlockThreshold
4343
import com.google.firebase.vertexai.type.HarmCategory
4444
import com.google.firebase.vertexai.type.HarmProbability
4545
import com.google.firebase.vertexai.type.HarmSeverity
4646
import com.google.firebase.vertexai.type.ImagePart
47+
import com.google.firebase.vertexai.type.InlineDataPart
4748
import com.google.firebase.vertexai.type.Part
4849
import com.google.firebase.vertexai.type.PromptFeedback
4950
import com.google.firebase.vertexai.type.SafetyRating
@@ -55,6 +56,7 @@ import com.google.firebase.vertexai.type.ToolConfig
5556
import com.google.firebase.vertexai.type.UsageMetadata
5657
import com.google.firebase.vertexai.type.content
5758
import java.io.ByteArrayOutputStream
59+
import java.util.Calendar
5860
import kotlinx.serialization.json.Json
5961
import kotlinx.serialization.json.JsonObject
6062
import org.json.JSONObject
@@ -71,12 +73,12 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
7173
return when (this) {
7274
is TextPart -> com.google.firebase.vertexai.common.shared.TextPart(text)
7375
is ImagePart ->
74-
com.google.firebase.vertexai.common.shared.BlobPart(
75-
Blob("image/jpeg", encodeBitmapToBase64Png(image))
76+
com.google.firebase.vertexai.common.shared.InlineDataPart(
77+
InlineData("image/jpeg", encodeBitmapToBase64Png(image))
7678
)
77-
is BlobPart ->
78-
com.google.firebase.vertexai.common.shared.BlobPart(
79-
Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS))
79+
is InlineDataPart ->
80+
com.google.firebase.vertexai.common.shared.InlineDataPart(
81+
InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS))
8082
)
8183
is com.google.firebase.vertexai.type.FunctionCallPart ->
8284
FunctionCallPart(FunctionCall(name, args.orEmpty()))
@@ -96,7 +98,8 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
9698
internal fun SafetySetting.toInternal() =
9799
com.google.firebase.vertexai.common.shared.SafetySetting(
98100
harmCategory.toInternal(),
99-
threshold.toInternal()
101+
threshold.toInternal(),
102+
method.toInternal()
100103
)
101104

102105
internal fun GenerationConfig.toInternal() =
@@ -107,11 +110,13 @@ internal fun GenerationConfig.toInternal() =
107110
candidateCount = candidateCount,
108111
maxOutputTokens = maxOutputTokens,
109112
stopSequences = stopSequences,
113+
frequencyPenalty = frequencyPenalty,
114+
presencePenalty = presencePenalty,
110115
responseMimeType = responseMimeType,
111116
responseSchema = responseSchema?.toInternal()
112117
)
113118

114-
internal fun com.google.firebase.vertexai.type.HarmCategory.toInternal() =
119+
internal fun HarmCategory.toInternal() =
115120
when (this) {
116121
HarmCategory.HARASSMENT -> com.google.firebase.vertexai.common.shared.HarmCategory.HARASSMENT
117122
HarmCategory.HATE_SPEECH -> com.google.firebase.vertexai.common.shared.HarmCategory.HATE_SPEECH
@@ -122,6 +127,13 @@ internal fun com.google.firebase.vertexai.type.HarmCategory.toInternal() =
122127
HarmCategory.UNKNOWN -> com.google.firebase.vertexai.common.shared.HarmCategory.UNKNOWN
123128
}
124129

130+
internal fun HarmBlockMethod.toInternal() =
131+
when (this) {
132+
HarmBlockMethod.SEVERITY -> com.google.firebase.vertexai.common.shared.HarmBlockMethod.SEVERITY
133+
HarmBlockMethod.PROBABILITY ->
134+
com.google.firebase.vertexai.common.shared.HarmBlockMethod.PROBABILITY
135+
}
136+
125137
internal fun ToolConfig.toInternal() =
126138
com.google.firebase.vertexai.common.client.ToolConfig(
127139
com.google.firebase.vertexai.common.client.FunctionCallingConfig(
@@ -150,7 +162,9 @@ internal fun HarmBlockThreshold.toInternal() =
150162
}
151163

152164
internal fun Tool.toInternal() =
153-
com.google.firebase.vertexai.common.client.Tool(functionDeclarations.map { it.toInternal() })
165+
com.google.firebase.vertexai.common.client.Tool(
166+
functionDeclarations?.map { it.toInternal() } ?: emptyList()
167+
)
154168

155169
internal fun FunctionDeclaration.toInternal() =
156170
com.google.firebase.vertexai.common.client.FunctionDeclaration(name, "", schema.toInternal())
@@ -191,12 +205,12 @@ internal fun com.google.firebase.vertexai.common.shared.Content.toPublic(): Cont
191205
internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
192206
return when (this) {
193207
is com.google.firebase.vertexai.common.shared.TextPart -> TextPart(text)
194-
is com.google.firebase.vertexai.common.shared.BlobPart -> {
208+
is com.google.firebase.vertexai.common.shared.InlineDataPart -> {
195209
val data = Base64.decode(inlineData.data, BASE_64_FLAGS)
196210
if (inlineData.mimeType.contains("image")) {
197211
ImagePart(decodeBitmapFromImage(data))
198212
} else {
199-
BlobPart(inlineData.mimeType, data)
213+
InlineDataPart(inlineData.mimeType, data)
200214
}
201215
}
202216
is FunctionCallPart ->
@@ -218,8 +232,29 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
218232
}
219233
}
220234

221-
internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic() =
222-
Citation(startIndex = startIndex, endIndex = endIndex, uri = uri, license = license)
235+
internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation {
236+
val publicationDateAsCalendar =
237+
publicationDate?.let {
238+
val calendar = Calendar.getInstance()
239+
// Internal `Date.year` uses 0 to represent not specified. We use 1 as default.
240+
val year = if (it.year == null || it.year < 1) 1 else it.year
241+
// Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The month as
242+
// expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default.
243+
val month = if (it.month == null || it.month < 1) 0 else it.month - 1
244+
// Internal `Date.day` uses 0 to represent not specified. We use 1 as default.
245+
val day = if (it.day == null || it.day < 1) 1 else it.day
246+
calendar.set(year, month, day)
247+
calendar
248+
}
249+
return Citation(
250+
title = title,
251+
startIndex = startIndex,
252+
endIndex = endIndex,
253+
uri = uri,
254+
license = license,
255+
publicationDate = publicationDateAsCalendar
256+
)
257+
}
223258

224259
internal fun com.google.firebase.vertexai.common.server.CitationMetadata.toPublic() =
225260
CitationMetadata(citationSources.map { it.toPublic() })

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.firebase.vertexai.type
1818

19+
import java.util.Calendar
20+
1921
/** A response generated by the model. */
2022
class Candidate
2123
internal constructor(
@@ -48,19 +50,23 @@ class CitationMetadata internal constructor(val citations: List<Citation>)
4850
* Provides citation information for sourcing of content provided by the model between a given
4951
* [startIndex] and [endIndex].
5052
*
53+
* @property title Title of the attribution.
5154
* @property startIndex The inclusive beginning of a sequence in a model response that derives from
5255
* a cited source.
5356
* @property endIndex The exclusive end of a sequence in a model response that derives from a cited
5457
* source.
5558
* @property uri A link to the cited source, if available.
5659
* @property license The license the cited source work is distributed under, if specified.
60+
* @property publicationDate Publication date of the attribution, if available.
5761
*/
5862
class Citation
5963
internal constructor(
64+
val title: String? = null,
6065
val startIndex: Int = 0,
6166
val endIndex: Int,
6267
val uri: String? = null,
63-
val license: String? = null
68+
val license: String? = null,
69+
val publicationDate: Calendar? = null
6470
)
6571

6672
/** The reason for content finishing. */

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@ class Content @JvmOverloads constructor(val role: String? = "user", val parts: L
4848
@JvmName("addText") fun text(text: String) = part(TextPart(text))
4949

5050
/**
51-
* Wraps the provided [blob] and [mimeType] inside a [BlobPart] and adds it to the [parts] list.
51+
* Wraps the provided [bytes] and [mimeType] inside a [InlineDataPart] and adds it to the
52+
* [parts] list.
5253
*/
53-
@JvmName("addBlob") fun blob(mimeType: String, blob: ByteArray) = part(BlobPart(mimeType, blob))
54+
@JvmName("addInlineData")
55+
fun inlineData(mimeType: String, bytes: ByteArray) = part(InlineDataPart(mimeType, bytes))
5456

5557
/** Wraps the provided [image] inside an [ImagePart] and adds it to the [parts] list. */
5658
@JvmName("addImage") fun image(image: Bitmap) = part(ImagePart(image))

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.firebase.vertexai.type
1818

19+
import com.google.firebase.vertexai.common.client.FunctionCallingConfig
20+
1921
/**
2022
* Contains configuration for function calling from the model. This can be used to force function
2123
* calling predictions or disable them.
@@ -25,10 +27,14 @@ package com.google.firebase.vertexai.type
2527
* should match [FunctionDeclaration.name]. With [Mode.ANY], model will predict a function call from
2628
* the set of function names provided.
2729
*/
28-
class FunctionCallingConfig(val mode: Mode, val allowedFunctionNames: List<String>? = null) {
30+
class FunctionCallingConfig
31+
internal constructor(
32+
internal val mode: Mode,
33+
internal val allowedFunctionNames: List<String>? = null
34+
) {
2935

3036
/** Configuration for dictating when the model should call the attached function. */
31-
enum class Mode {
37+
internal enum class Mode {
3238
/**
3339
* The default behavior for function calling. The model calls functions to answer queries at its
3440
* discretion
@@ -44,4 +50,24 @@ class FunctionCallingConfig(val mode: Mode, val allowedFunctionNames: List<Strin
4450
*/
4551
NONE
4652
}
53+
54+
companion object {
55+
/**
56+
* The default behavior for function calling. The model calls functions to answer queries at its
57+
* discretion
58+
*/
59+
@JvmStatic fun auto() = FunctionCallingConfig(Mode.AUTO)
60+
61+
/** The model always predicts a provided function call to answer every query. */
62+
@JvmStatic
63+
@JvmOverloads
64+
fun any(allowedFunctionNames: List<String>? = null) =
65+
FunctionCallingConfig(Mode.ANY, allowedFunctionNames)
66+
67+
/**
68+
* The model will never predict a function call to answer a query. This can also be achieved by
69+
* not passing any tools to the model.
70+
*/
71+
@JvmStatic fun none() = FunctionCallingConfig(Mode.NONE)
72+
}
4773
}

0 commit comments

Comments
 (0)