Skip to content

Commit d675a68

Browse files
committed
[Vertex AI] Integrate with AppCheckInterop (#12856)
1 parent aead392 commit d675a68

File tree

8 files changed

+163
-6
lines changed

8 files changed

+163
-6
lines changed

FirebaseVertexAI/Sources/GenerativeAIService.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
1617
import FirebaseCore
1718
import Foundation
1819

@@ -29,11 +30,14 @@ struct GenerativeAIService {
2930

3031
private let appCheck: AppCheckInterop?
3132

33+
private let auth: AuthInterop?
34+
3235
private let urlSession: URLSession
3336

34-
init(apiKey: String, appCheck: AppCheckInterop?, urlSession: URLSession) {
37+
init(apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?, urlSession: URLSession) {
3538
self.apiKey = apiKey
3639
self.appCheck = appCheck
40+
self.auth = auth
3741
self.urlSession = urlSession
3842
}
3943

@@ -176,6 +180,10 @@ struct GenerativeAIService {
176180
}
177181
}
178182

183+
if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
184+
urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
185+
}
186+
179187
let encoder = JSONEncoder()
180188
encoder.keyEncodingStrategy = .convertToSnakeCase
181189
urlRequest.httpBody = try encoder.encode(request)

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
1617
import Foundation
1718

1819
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
@@ -69,11 +70,13 @@ public final class GenerativeModel {
6970
systemInstruction: ModelContent? = nil,
7071
requestOptions: RequestOptions,
7172
appCheck: AppCheckInterop?,
73+
auth: AuthInterop?,
7274
urlSession: URLSession = .shared) {
7375
modelResourceName = GenerativeModel.modelResourceName(name: name)
7476
generativeAIService = GenerativeAIService(
7577
apiKey: apiKey,
7678
appCheck: appCheck,
79+
auth: auth,
7780
urlSession: urlSession
7881
)
7982
self.generationConfig = generationConfig

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
1617
import FirebaseCore
1718
import Foundation
1819

@@ -92,7 +93,8 @@ public class VertexAI: NSObject {
9293
toolConfig: toolConfig,
9394
systemInstruction: systemInstruction,
9495
requestOptions: requestOptions,
95-
appCheck: appCheck
96+
appCheck: appCheck,
97+
auth: auth
9698
)
9799
}
98100

@@ -103,12 +105,15 @@ public class VertexAI: NSObject {
103105

104106
private let appCheck: AppCheckInterop?
105107

108+
private let auth: AuthInterop?
109+
106110
let location: String
107111

108112
init(app: FirebaseApp, location: String) {
109113
self.app = app
110114
self.location = location
111115
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container)
116+
auth = ComponentType<AuthInterop>.instance(for: AuthInterop.self, in: app.container)
112117
}
113118

114119
private func modelResourceName(modelName: String, location: String) -> String {

FirebaseVertexAI/Sources/VertexAIComponent.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
1617
import FirebaseCore
1718
import Foundation
1819

@@ -51,10 +52,12 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {
5152

5253
static func componentsToRegister() -> [Component] {
5354
let appCheckInterop = Dependency(with: AppCheckInterop.self, isRequired: false)
55+
let authInterop = Dependency(with: AuthInterop.self, isRequired: false)
5456
return [Component(VertexAIProvider.self,
5557
instantiationTiming: .lazy,
5658
dependencies: [
5759
appCheckInterop,
60+
authInterop,
5861
]) { container, isCacheable in
5962
guard let app = container.app else { return nil }
6063
isCacheable.pointee = true

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ final class ChatTests: XCTestCase {
5353
tools: nil,
5454
requestOptions: RequestOptions(),
5555
appCheck: nil,
56+
auth: nil,
5657
urlSession: urlSession
5758
)
5859
let chat = Chat(model: model, history: [])
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseAuthInterop
16+
import Foundation
17+
18+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
19+
class AuthInteropFake: NSObject, AuthInterop {
20+
let token: String?
21+
let error: Error?
22+
23+
func getToken(forcingRefresh forceRefresh: Bool) async throws -> String? {
24+
if let error {
25+
throw error
26+
}
27+
28+
return token
29+
}
30+
31+
func getUserID() -> String? {
32+
fatalError("\(#function) not implemented.")
33+
}
34+
35+
private init(token: String?, error: Error?) {
36+
self.token = token
37+
self.error = error
38+
}
39+
40+
convenience init(error: Error) {
41+
self.init(token: nil, error: error)
42+
}
43+
44+
convenience init(token: String?) {
45+
self.init(token: token, error: nil)
46+
}
47+
}
48+
49+
struct AuthErrorFake: Error {}

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
1617
import FirebaseCore
1718
import XCTest
1819

@@ -41,6 +42,7 @@ final class GenerativeModelTests: XCTestCase {
4142
tools: nil,
4243
requestOptions: RequestOptions(),
4344
appCheck: nil,
45+
auth: nil,
4446
urlSession: urlSession
4547
)
4648
}
@@ -182,6 +184,7 @@ final class GenerativeModelTests: XCTestCase {
182184
tools: nil,
183185
requestOptions: RequestOptions(),
184186
appCheck: nil,
187+
auth: nil,
185188
urlSession: urlSession
186189
)
187190

@@ -266,6 +269,7 @@ final class GenerativeModelTests: XCTestCase {
266269
tools: nil,
267270
requestOptions: RequestOptions(),
268271
appCheck: AppCheckInteropFake(token: appCheckToken),
272+
auth: nil,
269273
urlSession: urlSession
270274
)
271275
MockURLProtocol
@@ -285,6 +289,7 @@ final class GenerativeModelTests: XCTestCase {
285289
tools: nil,
286290
requestOptions: RequestOptions(),
287291
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
292+
auth: nil,
288293
urlSession: urlSession
289294
)
290295
MockURLProtocol
@@ -297,6 +302,74 @@ final class GenerativeModelTests: XCTestCase {
297302
_ = try await model.generateContent(testPrompt)
298303
}
299304

305+
func testGenerateContent_auth_validAuthToken() async throws {
306+
let authToken = "test-valid-token"
307+
model = GenerativeModel(
308+
name: "my-model",
309+
apiKey: "API_KEY",
310+
tools: nil,
311+
requestOptions: RequestOptions(),
312+
appCheck: nil,
313+
auth: AuthInteropFake(token: authToken),
314+
urlSession: urlSession
315+
)
316+
MockURLProtocol
317+
.requestHandler = try httpRequestHandler(
318+
forResource: "unary-success-basic-reply-short",
319+
withExtension: "json",
320+
authToken: authToken
321+
)
322+
323+
_ = try await model.generateContent(testPrompt)
324+
}
325+
326+
func testGenerateContent_auth_nilAuthToken() async throws {
327+
model = GenerativeModel(
328+
name: "my-model",
329+
apiKey: "API_KEY",
330+
tools: nil,
331+
requestOptions: RequestOptions(),
332+
appCheck: nil,
333+
auth: AuthInteropFake(token: nil),
334+
urlSession: urlSession
335+
)
336+
MockURLProtocol
337+
.requestHandler = try httpRequestHandler(
338+
forResource: "unary-success-basic-reply-short",
339+
withExtension: "json",
340+
authToken: nil
341+
)
342+
343+
_ = try await model.generateContent(testPrompt)
344+
}
345+
346+
func testGenerateContent_auth_authTokenRefreshError() async throws {
347+
model = GenerativeModel(
348+
name: "my-model",
349+
apiKey: "API_KEY",
350+
tools: nil,
351+
requestOptions: RequestOptions(),
352+
appCheck: nil,
353+
auth: AuthInteropFake(error: AuthErrorFake()),
354+
urlSession: urlSession
355+
)
356+
MockURLProtocol
357+
.requestHandler = try httpRequestHandler(
358+
forResource: "unary-success-basic-reply-short",
359+
withExtension: "json",
360+
authToken: nil
361+
)
362+
363+
do {
364+
_ = try await model.generateContent(testPrompt)
365+
XCTFail("Should throw internalError(AuthErrorFake); no error.")
366+
} catch GenerateContentError.internalError(_ as AuthErrorFake) {
367+
//
368+
} catch {
369+
XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
370+
}
371+
}
372+
300373
func testGenerateContent_usageMetadata() async throws {
301374
MockURLProtocol
302375
.requestHandler = try httpRequestHandler(
@@ -598,6 +671,7 @@ final class GenerativeModelTests: XCTestCase {
598671
tools: nil,
599672
requestOptions: requestOptions,
600673
appCheck: nil,
674+
auth: nil,
601675
urlSession: urlSession
602676
)
603677

@@ -808,6 +882,7 @@ final class GenerativeModelTests: XCTestCase {
808882
tools: nil,
809883
requestOptions: RequestOptions(),
810884
appCheck: AppCheckInteropFake(token: appCheckToken),
885+
auth: nil,
811886
urlSession: urlSession
812887
)
813888
MockURLProtocol
@@ -828,6 +903,7 @@ final class GenerativeModelTests: XCTestCase {
828903
tools: nil,
829904
requestOptions: RequestOptions(),
830905
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
906+
auth: nil,
831907
urlSession: urlSession
832908
)
833909
MockURLProtocol
@@ -972,6 +1048,7 @@ final class GenerativeModelTests: XCTestCase {
9721048
tools: nil,
9731049
requestOptions: requestOptions,
9741050
appCheck: nil,
1051+
auth: nil,
9751052
urlSession: urlSession
9761053
)
9771054

@@ -1048,6 +1125,7 @@ final class GenerativeModelTests: XCTestCase {
10481125
tools: nil,
10491126
requestOptions: requestOptions,
10501127
appCheck: nil,
1128+
auth: nil,
10511129
urlSession: urlSession
10521130
)
10531131

@@ -1067,7 +1145,8 @@ final class GenerativeModelTests: XCTestCase {
10671145
apiKey: "API_KEY",
10681146
tools: nil,
10691147
requestOptions: RequestOptions(),
1070-
appCheck: nil
1148+
appCheck: nil,
1149+
auth: nil
10711150
)
10721151

10731152
XCTAssertEqual(model.modelResourceName, modelResourceName)
@@ -1081,7 +1160,8 @@ final class GenerativeModelTests: XCTestCase {
10811160
apiKey: "API_KEY",
10821161
tools: nil,
10831162
requestOptions: RequestOptions(),
1084-
appCheck: nil
1163+
appCheck: nil,
1164+
auth: nil
10851165
)
10861166

10871167
XCTAssertEqual(model.modelResourceName, modelResourceName)
@@ -1095,7 +1175,8 @@ final class GenerativeModelTests: XCTestCase {
10951175
apiKey: "API_KEY",
10961176
tools: nil,
10971177
requestOptions: RequestOptions(),
1098-
appCheck: nil
1178+
appCheck: nil,
1179+
auth: nil
10991180
)
11001181

11011182
XCTAssertEqual(model.modelResourceName, tunedModelResourceName)
@@ -1123,7 +1204,8 @@ final class GenerativeModelTests: XCTestCase {
11231204
withExtension ext: String,
11241205
statusCode: Int = 200,
11251206
timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
1126-
appCheckToken: String? = nil) throws -> ((URLRequest) throws -> (
1207+
appCheckToken: String? = nil,
1208+
authToken: String? = nil) throws -> ((URLRequest) throws -> (
11271209
URLResponse,
11281210
AsyncLineSequence<URL.AsyncBytes>?
11291211
)) {
@@ -1137,6 +1219,11 @@ final class GenerativeModelTests: XCTestCase {
11371219
XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
11381220
XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))
11391221
XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
1222+
if let authToken {
1223+
XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(authToken)")
1224+
} else {
1225+
XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
1226+
}
11401227
let response = try XCTUnwrap(HTTPURLResponse(
11411228
url: requestURL,
11421229
statusCode: statusCode,

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,7 @@ let package = Package(
13681368
name: "FirebaseVertexAI",
13691369
dependencies: [
13701370
"FirebaseAppCheckInterop",
1371+
"FirebaseAuthInterop",
13711372
"FirebaseCore",
13721373
"FirebaseCoreExtension",
13731374
],

0 commit comments

Comments
 (0)