Skip to content

Add usageMetadata to GenerateContentResponse #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion Sources/GoogleAI/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
public struct UsageMetadata {
/// The number of tokens in the request prompt.
public let promptTokenCount: Int

/// The total number of tokens across the generated response candidates.
public let candidatesTokenCount: Int

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
Expand Down Expand Up @@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) {
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

Expand Down Expand Up @@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
Expand All @@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
enum CodingKeys: CodingKey {
case promptTokenCount
case candidatesTokenCount
case totalTokenCount
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
candidatesTokenCount = try container
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
data: {"candidates": [{"content": {"parts": [{"text": "Cheyenne"}]},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}}

data: {"candidates": [{"content": {"parts": [{"text": "Mountain View, California"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"candidatesTokenCount": 4}}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"content": {
"parts": [
{
"text": "Mountain View, California, United States"
"text": "Mountain View, California"
}
],
"role": "model"
Expand All @@ -31,24 +31,7 @@
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
"usageMetadata": {
"candidatesTokenCount": 4
}
}
49 changes: 45 additions & 4 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
XCTAssertEqual(part.text, "Mountain View, California, United States")
XCTAssertEqual(part.text, "Mountain View, California")
XCTAssertEqual(response.text, part.text)
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertNil(promptFeedback.blockReason)
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
XCTAssertNil(response.promptFeedback)
XCTAssertEqual(response.functionCalls, [])
}

Expand Down Expand Up @@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.functionCalls, [functionCall])
}

func testGenerateContent_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase {
}))
}

func testGenerateContentStream_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt"
)
var responses = [GenerateContentResponse]()

let stream = model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

for (index, response) in responses.enumerated() {
if index == responses.endIndex - 1 {
let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
} else {
// Only the last streamed response contains usage metadata
XCTAssertNil(response.usageMetadata)
}
}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down