Skip to content

Wrap citations inside CitationMetadata #6276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Base64
import com.google.firebase.vertexai.common.client.Schema
import com.google.firebase.vertexai.common.server.CitationSources
import com.google.firebase.vertexai.common.shared.Blob
import com.google.firebase.vertexai.common.shared.FileData
import com.google.firebase.vertexai.common.shared.FunctionCall
Expand All @@ -32,6 +31,7 @@ import com.google.firebase.vertexai.type.BlobPart
import com.google.firebase.vertexai.type.BlockReason
import com.google.firebase.vertexai.type.BlockThreshold
import com.google.firebase.vertexai.type.Candidate
import com.google.firebase.vertexai.type.Citation
import com.google.firebase.vertexai.type.CitationMetadata
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
Expand Down Expand Up @@ -181,7 +181,7 @@ internal fun JSONObject.toInternal() = Json.decodeFromString<JsonObject>(toStrin

internal fun com.google.firebase.vertexai.common.server.Candidate.toPublic(): Candidate {
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.toPublic()
val finishReason = finishReason.toPublic()

return Candidate(
Expand Down Expand Up @@ -228,8 +228,11 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
}
}

internal fun CitationSources.toPublic() =
CitationMetadata(startIndex = startIndex, endIndex = endIndex, uri = uri ?: "", license = license)
internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic() =
Citation(startIndex = startIndex, endIndex = endIndex, uri = uri, license = license)

internal fun com.google.firebase.vertexai.common.server.CitationMetadata.toPublic() =
CitationMetadata(citationSources.map { it.toPublic() })

internal fun com.google.firebase.vertexai.common.server.SafetyRating.toPublic() =
SafetyRating(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Candidate
internal constructor(
val content: Content,
val safetyRatings: List<SafetyRating>,
val citationMetadata: List<CitationMetadata>,
val citationMetadata: CitationMetadata?,
val finishReason: FinishReason?
)

Expand All @@ -37,15 +37,25 @@ internal constructor(
)

/**
* Provides citation metadata for sourcing of content provided by the model between a given
* A collection of source attributions for a piece of content.
*
* @property citations A list of individual cited sources and the parts of the content to which they
* apply.
*/
class CitationMetadata internal constructor(val citations: List<Citation>)

/**
* Provides citation information for sourcing of content provided by the model between a given
* [startIndex] and [endIndex].
*
* @property startIndex The beginning of the citation.
* @property endIndex The end of the citation.
* @property uri The URI of the cited work.
* @property license The license under which the cited work is distributed.
* @property startIndex The inclusive beginning of a sequence in a model response that derives from
* a cited source.
* @property endIndex The exclusive end of a sequence in a model response that derives from a cited
* source.
* @property uri A link to the cited source, if available.
* @property license The license the cited source work is distributed under, if specified.
*/
class CitationMetadata
class Citation
internal constructor(
val startIndex: Int = 0,
val endIndex: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ internal class StreamingSnapshotTests {

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true
responseList.any {
it.candidates.any { it.citationMetadata?.citations?.isNotEmpty() ?: false }
} shouldBe true
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().citationMetadata.size shouldBe 3
response.candidates.first().citationMetadata?.citations?.size shouldBe 3
}
}

Expand All @@ -240,11 +240,14 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().citationMetadata.isEmpty() shouldBe false
response.candidates.first().citationMetadata?.citations?.isEmpty() shouldBe false
// Verify the values in the citation source
with(response.candidates.first().citationMetadata.first()) {
license shouldBe null
startIndex shouldBe 0
val firstCitation = response.candidates.first().citationMetadata?.citations?.first()
if (firstCitation != null) {
with(firstCitation) {
license shouldBe null
startIndex shouldBe 0
}
}
}
}
Expand Down
Loading