Skip to content

Commit e9af231

Browse files
authored
Add support for allowed_function_names (#6273)
A note about `@SerialName("allowed_function_names")`: This is not strictly necessary, the backend will parse it if it's camel case too. We should eventually remove all unnecessary `@SerialName` declarations, but for now, and to keep consistency, I'm adding it to this declaration.
1 parent c56bc5d commit e9af231

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ internal data class ToolConfig(
4747
)
4848

4949
@Serializable
50-
internal data class FunctionCallingConfig(val mode: Mode) {
50+
internal data class FunctionCallingConfig(
51+
val mode: Mode,
52+
@SerialName("allowed_function_names") val allowedFunctionNames: List<String>? = null
53+
) {
5154
@Serializable
5255
enum class Mode {
5356
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ internal fun ToolConfig.toInternal() =
137137
com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.AUTO
138138
FunctionCallingConfig.Mode.NONE ->
139139
com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.NONE
140-
}
140+
},
141+
functionCallingConfig.allowedFunctionNames
141142
)
142143
)
143144

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ package com.google.firebase.vertexai.type
2121
* calling predictions or disable them.
2222
*
2323
* @param mode The function calling mode of the model
24+
* @param allowedFunctionNames Function names to call. Only set when the [Mode.ANY]. Function names
25+
* should match [FunctionDeclaration.name]. With [Mode.ANY], model will predict a function call from
26+
* the set of function names provided.
2427
*/
25-
class FunctionCallingConfig(val mode: Mode) {
28+
class FunctionCallingConfig(val mode: Mode, val allowedFunctionNames: List<String>? = null) {
2629

2730
/** Configuration for dictating when the model should call the attached function. */
2831
enum class Mode {

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ class ToolConfig(val functionCallingConfig: FunctionCallingConfig) {
2727
companion object {
2828
/** Shorthand to construct a ToolConfig that restricts the model from calling any functions */
2929
fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE))
30-
/** Shorthand to construct a ToolConfig that restricts the model to always call some function */
31-
fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY))
30+
/**
31+
* Shorthand to construct a ToolConfig that restricts the model to always call some function.
32+
* You can optionally [allowedFunctionNames] to restrict the model to only call these functions.
33+
* See [FunctionCallingConfig] for more information.
34+
*/
35+
fun always(allowedFunctionNames: List<String>? = null): ToolConfig =
36+
ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY, allowedFunctionNames))
3237
}
3338
}

firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,21 @@ internal class RequestFormatTests {
185185
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
186186
toolConfig =
187187
ToolConfig(
188-
functionCallingConfig =
189-
FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO)
190-
),
191-
)
188+
FunctionCallingConfig(
189+
mode = FunctionCallingConfig.Mode.ANY,
190+
allowedFunctionNames = listOf("allowedFunctionName")
191+
)
192+
)
193+
),
192194
)
193195
.collect { channel.close() }
194196
}
195197

196198
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
197199

198200
requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
201+
requestBodyAsText shouldContainJsonKey
202+
"tool_config.function_calling_config.allowed_function_names"
199203
}
200204

201205
@Test

0 commit comments

Comments
 (0)