Skip to content

Revert "Introduce FunctionCall and FunctionResponse types (#6311)" #6360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion firebase-vertexai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
* [feature] Added support for `title` and `publicationDate` in citations. (#6309)
* [feature] Added support for `frequencyPenalty`, `presencePenalty`, and `HarmBlockMethod`. (#6309)
* [changed] **Breaking Change**: Introduced `Citations` class. Now `CitationMetadata` wraps that type. (#6276)
* [changed] **Breaking Change**: Introduced `FunctionCall` and `FunctionResponse` types. Now `FunctionCallPart` and `FunctionResponsePart` wrap those types, respectively. (#6311)
* [changed] **Breaking Change**: Reworked `Schema` declaration mechanism. (#6258)
* [changed] **Breaking Change**: Reworked function calling mechanism to use the new `Schema` format. Function calls no longer use native types, nor include references to the actual executable code. (#6258)
* [changed] **Breaking Change**: Made `totalBillableCharacters` field in `CountTokens` nullable and optional. (#6294)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import android.graphics.BitmapFactory
import android.util.Base64
import com.google.firebase.vertexai.common.client.Schema
import com.google.firebase.vertexai.common.shared.FileData
import com.google.firebase.vertexai.common.shared.FunctionCall
import com.google.firebase.vertexai.common.shared.FunctionCallPart
import com.google.firebase.vertexai.common.shared.FunctionResponse
import com.google.firebase.vertexai.common.shared.FunctionResponsePart
import com.google.firebase.vertexai.common.shared.InlineData
import com.google.firebase.vertexai.type.BlockReason
import com.google.firebase.vertexai.type.Candidate
Expand All @@ -30,12 +34,8 @@ import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FileDataPart
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.FunctionCall
import com.google.firebase.vertexai.type.FunctionCallPart
import com.google.firebase.vertexai.type.FunctionCallingConfig
import com.google.firebase.vertexai.type.FunctionDeclaration
import com.google.firebase.vertexai.type.FunctionResponse
import com.google.firebase.vertexai.type.FunctionResponsePart
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.HarmBlockMethod
Expand Down Expand Up @@ -81,10 +81,10 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
com.google.firebase.vertexai.common.shared.InlineDataPart(
InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS))
)
is FunctionCallPart ->
com.google.firebase.vertexai.common.shared.FunctionCallPart(functionCall.toInternal())
is FunctionResponsePart ->
com.google.firebase.vertexai.common.shared.FunctionResponsePart(functionResponse.toInternal())
is com.google.firebase.vertexai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args))
is com.google.firebase.vertexai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response))
is FileDataPart ->
com.google.firebase.vertexai.common.shared.FileDataPart(
FileData(mimeType = mimeType, fileUri = uri)
Expand All @@ -96,12 +96,6 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
}
}

internal fun FunctionCall.toInternal() =
com.google.firebase.vertexai.common.shared.FunctionCall(name, args)

internal fun FunctionResponse.toInternal() =
com.google.firebase.vertexai.common.shared.FunctionResponse(name, response)

internal fun SafetySetting.toInternal() =
com.google.firebase.vertexai.common.shared.SafetySetting(
harmCategory.toInternal(),
Expand Down Expand Up @@ -235,10 +229,16 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
InlineDataPart(data, inlineData.mimeType)
}
}
is com.google.firebase.vertexai.common.shared.FunctionCallPart ->
FunctionCallPart(functionCall.toPublic())
is com.google.firebase.vertexai.common.shared.FunctionResponsePart ->
FunctionResponsePart(functionResponse.toPublic())
is FunctionCallPart ->
com.google.firebase.vertexai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty().mapValues { it.value ?: JsonNull }
)
is FunctionResponsePart ->
com.google.firebase.vertexai.type.FunctionResponsePart(
functionResponse.name,
functionResponse.response,
)
is com.google.firebase.vertexai.common.shared.FileDataPart ->
FileDataPart(fileData.mimeType, fileData.fileUri)
else ->
Expand All @@ -248,15 +248,6 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
}
}

internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() =
FunctionCall(name, args.orEmpty().mapValues { it.value ?: JsonNull })

internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() =
FunctionResponse(
name,
response,
)

internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation {
val publicationDateAsCalendar =
publicationDate?.let {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.firebase.vertexai.type
import android.graphics.Bitmap
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import org.json.JSONObject

/** Interface representing data sent to and received from requests. */
public interface Part
Expand All @@ -44,35 +45,21 @@ public class ImagePart(public val image: Bitmap) : Part
public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part

/**
* Represents a function call request from the model
*
* @param functionCall The information provided by the model to call a function.
*/
public class FunctionCallPart(public val functionCall: FunctionCall) : Part

/**
* The result of calling a function as requested by the model.
*
* @param functionResponse The information to send back to the model as the result of a functions
* call.
*/
public class FunctionResponsePart(public val functionResponse: FunctionResponse) : Part

/**
* The data necessary to invoke function [name] using the arguments [args].
* Represents function call name and params received from requests.
*
* @param name the name of the function to call
* @param args the function parameters and values as a [Map]
*/
public class FunctionCall(public val name: String, public val args: Map<String, JsonElement>)
public class FunctionCallPart(public val name: String, public val args: Map<String, JsonElement>) :
Part

/**
* The [response] generated after calling function [name].
* Represents function call output to be returned to the model when it requests a function call.
*
* @param name the name of the called function
* @param response the response produced by the function as a [JsonObject]
* @param response the response produced by the function as a [JSONObject]
*/
public class FunctionResponse(public val name: String, public val response: JsonObject)
public class FunctionResponsePart(public val name: String, public val response: JsonObject) : Part

/**
* Represents file data stored in Cloud Storage for Firebase, referenced by URI.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart)

callPart.functionCall.args["season"] shouldBe JsonPrimitive(null)
callPart.args["season"] shouldBe JsonPrimitive(null)
}
}

Expand All @@ -370,7 +370,7 @@ internal class UnarySnapshotTests {
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
}

callPart.functionCall.args["current"] shouldBe JsonPrimitive(true)
callPart.args["current"] shouldBe JsonPrimitive(true)
}
}

Expand All @@ -387,11 +387,9 @@ internal class UnarySnapshotTests {
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
}

callPart.functionCall.args["current"] shouldBe JsonPrimitive(true)
callPart.functionCall.args["testObject"]!!
.jsonObject["testProperty"]!!
.jsonPrimitive
.content shouldBe "string property"
callPart.args["current"] shouldBe JsonPrimitive(true)
callPart.args["testObject"]!!.jsonObject["testProperty"]!!.jsonPrimitive.content shouldBe
"string property"
}
}

Expand All @@ -402,8 +400,8 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.functionCall.name shouldBe "current_time"
callPart.functionCall.args.isEmpty() shouldBe true
callPart.name shouldBe "current_time"
callPart.args.isEmpty() shouldBe true
}
}

Expand All @@ -414,9 +412,9 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.functionCall.name shouldBe "sum"
callPart.functionCall.args["x"] shouldBe JsonPrimitive(4)
callPart.functionCall.args["y"] shouldBe JsonPrimitive(5)
callPart.name shouldBe "sum"
callPart.args["x"] shouldBe JsonPrimitive(4)
callPart.args["y"] shouldBe JsonPrimitive(5)
}
}

Expand All @@ -429,8 +427,8 @@ internal class UnarySnapshotTests {

callList.size shouldBe 3
callList.forEach {
it.functionCall.name shouldBe "sum"
it.functionCall.args.size shouldBe 2
it.name shouldBe "sum"
it.args.size shouldBe 2
}
}
}
Expand All @@ -444,7 +442,7 @@ internal class UnarySnapshotTests {

response.text shouldBe "The sum of [1, 2, 3] is"
callList.size shouldBe 2
callList.forEach { it.functionCall.args.size shouldBe 2 }
callList.forEach { it.args.size shouldBe 2 }
}
}

Expand Down
Loading