Skip to content

Commit 0341e9c

Browse files
ML model downloader fixes (#7515)
* listDownloadedModels: don't cross-check URLs * Store models in a subfolder. * Test completion and progress handlers being called on the main thread. * More tests for callbacks on the main queue. * Fix public API handlers to be called on the main queue. * .gitignore ML GoogleService-Info.plist files * style * getModel: fail fast when modelName is empty string
1 parent c25f724 commit 0341e9c

File tree

4 files changed

+113
-31
lines changed

4 files changed

+113
-31
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ FirebaseSegmentation/Tests/Sample/GoogleService-Info.plist
3636

3737
# FirebaseMLModelDownloader integration tests GoogleService-Info.plist
3838
FirebaseMLModelDownloader/Tests/Integration/Resources/GoogleService-Info.plist
39-
FirebaseMLModelDownloader/Apps/Sample/Resources/GoogleService-Info.plist
39+
FirebaseMLModelDownloader/Apps/Sample/**/GoogleService-Info.plist
4040

4141
# FirebasePerformance dev test App and integration tests GoogleService-Info.plist
4242
FirebasePerformance/**/GoogleService-Info.plist

FirebaseMLModelDownloader/Sources/ModelDownloader.swift

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ public class ModelDownloader {
5656
/// DispatchQueue to manage download task dictionary.
5757
let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
5858

59-
/// Handler that always runs on the main thread
60-
let mainQueueHandler = { handler in
59+
/// Re-dispatch a function on the main queue.
60+
func asyncOnMainQueue(_ work: @autoclosure @escaping () -> Void) {
6161
DispatchQueue.main.async {
62-
handler
62+
work()
6363
}
6464
}
6565

@@ -132,13 +132,18 @@ public class ModelDownloader {
132132
conditions: ModelDownloadConditions,
133133
progressHandler: ((Float) -> Void)? = nil,
134134
completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
135+
guard !modelName.isEmpty else {
136+
asyncOnMainQueue(completion(.failure(.emptyModelName)))
137+
return
138+
}
139+
135140
switch downloadType {
136141
case .localModel:
137142
if let localModel = getLocalModel(modelName: modelName) {
138143
DeviceLogger.logEvent(level: .debug,
139144
message: ModelDownloader.DebugDescription.localModelFound,
140145
messageCode: .localModelFound)
141-
mainQueueHandler(completion(.success(localModel)))
146+
asyncOnMainQueue(completion(.success(localModel)))
142147
} else {
143148
getRemoteModel(
144149
modelName: modelName,
@@ -153,7 +158,7 @@ public class ModelDownloader {
153158
DeviceLogger.logEvent(level: .debug,
154159
message: ModelDownloader.DebugDescription.localModelFound,
155160
messageCode: .localModelFound)
156-
mainQueueHandler(completion(.success(localModel)))
161+
asyncOnMainQueue(completion(.success(localModel)))
157162
telemetryLogger?.logModelDownloadEvent(
158163
eventName: .modelDownload,
159164
status: .scheduled,
@@ -225,7 +230,7 @@ public class ModelDownloader {
225230
DeviceLogger.logEvent(level: .debug,
226231
message: description,
227232
messageCode: .modelNameParseError)
228-
mainQueueHandler(completion(.failure(.internalError(description: description))))
233+
asyncOnMainQueue(completion(.failure(.internalError(description: description))))
229234
return
230235
}
231236
// Check if model information corresponding to model is stored in UserDefaults.
@@ -234,19 +239,19 @@ public class ModelDownloader {
234239
DeviceLogger.logEvent(level: .debug,
235240
message: description,
236241
messageCode: .noLocalModelInfo)
237-
mainQueueHandler(completion(.failure(.internalError(description: description))))
242+
asyncOnMainQueue(completion(.failure(.internalError(description: description))))
238243
return
239244
}
240245
// Ensure that local model path is as expected, and reachable.
241246
guard let modelURL = ModelFileManager.getDownloadedModelFileURL(
242247
appName: appName,
243248
modelName: modelName
244249
),
245-
url == modelURL, ModelFileManager.isFileReachable(at: modelURL) else {
250+
ModelFileManager.isFileReachable(at: modelURL) else {
246251
DeviceLogger.logEvent(level: .debug,
247252
message: ModelDownloader.ErrorDescription.outdatedModelPath,
248253
messageCode: .outdatedModelPathError)
249-
mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
254+
asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
250255
.ErrorDescription.outdatedModelPath))))
251256
return
252257
}
@@ -262,12 +267,12 @@ public class ModelDownloader {
262267
DeviceLogger.logEvent(level: .debug,
263268
message: ModelDownloader.ErrorDescription.listModelsFailed(error),
264269
messageCode: .listModelsError)
265-
mainQueueHandler(completion(.failure(error)))
270+
asyncOnMainQueue(completion(.failure(error)))
266271
} catch {
267272
DeviceLogger.logEvent(level: .debug,
268273
message: ModelDownloader.ErrorDescription.listModelsFailed(error),
269274
messageCode: .listModelsError)
270-
mainQueueHandler(completion(.failure(.internalError(description: error
275+
asyncOnMainQueue(completion(.failure(.internalError(description: error
271276
.localizedDescription))))
272277
}
273278
}
@@ -290,7 +295,7 @@ public class ModelDownloader {
290295
DeviceLogger.logEvent(level: .debug,
291296
message: ModelDownloader.ErrorDescription.modelNotFound(modelName),
292297
messageCode: .modelNotFound)
293-
mainQueueHandler(completion(.failure(.notFound)))
298+
asyncOnMainQueue(completion(.failure(.notFound)))
294299
return
295300
}
296301
do {
@@ -305,7 +310,7 @@ public class ModelDownloader {
305310
eventName: .remoteModelDeleteOnDevice,
306311
isSuccessful: true
307312
)
308-
mainQueueHandler(completion(.success(())))
313+
asyncOnMainQueue(completion(.success(())))
309314
} catch let error as DownloadedModelError {
310315
DeviceLogger.logEvent(level: .debug,
311316
message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
@@ -314,7 +319,7 @@ public class ModelDownloader {
314319
eventName: .remoteModelDeleteOnDevice,
315320
isSuccessful: false
316321
)
317-
mainQueueHandler(completion(.failure(error)))
322+
asyncOnMainQueue(completion(.failure(error)))
318323
} catch {
319324
DeviceLogger.logEvent(level: .debug,
320325
message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
@@ -323,7 +328,7 @@ public class ModelDownloader {
323328
eventName: .remoteModelDeleteOnDevice,
324329
isSuccessful: false
325330
)
326-
mainQueueHandler(completion(.failure(.internalError(description: error
331+
asyncOnMainQueue(completion(.failure(.internalError(description: error
327332
.localizedDescription))))
328333
}
329334
}
@@ -417,28 +422,28 @@ extension ModelDownloader {
417422
// Progress handler for model file download.
418423
let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
419424
if let progressHandler = progressHandler {
420-
self.mainQueueHandler(progressHandler(progress))
425+
self.asyncOnMainQueue(progressHandler(progress))
421426
}
422427
}
423428
// Completion handler for model file download.
424429
let taskCompletion: ModelDownloadTask.Completion = { result in
425430
switch result {
426431
case let .success(model):
427-
self.mainQueueHandler(completion(.success(model)))
432+
self.asyncOnMainQueue(completion(.success(model)))
428433
case let .failure(error):
429434
switch error {
430435
case .notFound:
431-
self.mainQueueHandler(completion(.failure(.notFound)))
436+
self.asyncOnMainQueue(completion(.failure(.notFound)))
432437
case .invalidArgument:
433-
self.mainQueueHandler(completion(.failure(.invalidArgument)))
438+
self.asyncOnMainQueue(completion(.failure(.invalidArgument)))
434439
case .permissionDenied:
435-
self.mainQueueHandler(completion(.failure(.permissionDenied)))
440+
self.asyncOnMainQueue(completion(.failure(.permissionDenied)))
436441
// This is the error returned when model download URL has expired.
437442
case .expiredDownloadURL:
438443
// Retry model info and model file download, if allowed.
439444
guard self.numberOfRetries > 0 else {
440445
self
441-
.mainQueueHandler(
446+
.asyncOnMainQueue(
442447
completion(.failure(.internalError(description: ModelDownloader
443448
.ErrorDescription
444449
.expiredModelInfo)))
@@ -458,7 +463,7 @@ extension ModelDownloader {
458463
completion: completion
459464
)
460465
default:
461-
self.mainQueueHandler(completion(.failure(error)))
466+
self.asyncOnMainQueue(completion(.failure(error)))
462467
}
463468
}
464469
self.taskSerialQueue.async {
@@ -502,15 +507,15 @@ extension ModelDownloader {
502507
guard let localModel = self.getLocalModel(modelName: modelName) else {
503508
// This can only happen if either local model info or the model file was wiped out after model info request but before server response.
504509
self
505-
.mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
510+
.asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
506511
.ErrorDescription.deletedLocalModelInfoOrFile))))
507512
return
508513
}
509-
self.mainQueueHandler(completion(.success(localModel)))
514+
self.asyncOnMainQueue(completion(.success(localModel)))
510515
}
511516
// Error retrieving model info.
512517
case let .failure(error):
513-
self.mainQueueHandler(completion(.failure(error)))
518+
self.asyncOnMainQueue(completion(.failure(error)))
514519
}
515520
}
516521
}
@@ -530,6 +535,8 @@ public enum DownloadError: Error, Equatable {
530535
case notEnoughSpace
531536
/// Malformed model name or Firebase app options.
532537
case invalidArgument
538+
/// Model name is empty.
539+
case emptyModelName
533540
/// Other errors with description.
534541
case internalError(description: String)
535542
}

FirebaseMLModelDownloader/Sources/ModelFileManager.swift

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,32 @@ enum ModelFileManager {
2727

2828
/// Root directory of model file storage on device.
2929
static var modelsDirectory: URL? {
30-
#if os(tvOS)
31-
return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first
32-
#else
33-
return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? nil
34-
#endif
30+
let rootDirOptional: URL? = {
31+
#if os(tvOS)
32+
return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first
33+
#else
34+
return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? nil
35+
#endif
36+
}()
37+
38+
guard let rootDirURL = rootDirOptional else {
39+
return nil
40+
}
41+
42+
let modelDirURL = rootDirURL.appendingPathComponent(
43+
"com.firebase.FirebaseMLModelDownloader",
44+
isDirectory: true
45+
)
46+
47+
do {
48+
if !fileManager.fileExists(atPath: modelDirURL.absoluteString) {
49+
try fileManager.createDirectory(at: modelDirURL, withIntermediateDirectories: true)
50+
}
51+
} catch {
52+
return nil
53+
}
54+
55+
return modelDirURL
3556
}
3657

3758
/// Name for model file stored on device.

FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,14 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
226226
downloadType: downloadType,
227227
conditions: conditions,
228228
progressHandler: { progress in
229+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
230+
229231
XCTAssertLessThanOrEqual(progress, 1)
230232
XCTAssertGreaterThanOrEqual(progress, 0)
231233
}
232234
) { result in
235+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
236+
233237
switch result {
234238
case let .success(model):
235239
XCTAssertNotNil(model.path)
@@ -299,6 +303,48 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
299303
wait(for: [localModelExpectation], timeout: 5)
300304
}
301305

306+
func testGetModelWhenNameIsEmpty() {
307+
guard let testApp = FirebaseApp.app() else {
308+
XCTFail("Default app was not configured.")
309+
return
310+
}
311+
let testName = String(#function.dropLast(2))
312+
let emptyModelName = ""
313+
314+
let conditions = ModelDownloadConditions()
315+
let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
316+
.createTestInstance(testName: testName),
317+
app: testApp
318+
)
319+
320+
let completionExpectation = expectation(description: "getModel")
321+
let progressExpectation = expectation(description: "progressHandler")
322+
progressExpectation.isInverted = true
323+
324+
modelDownloader.getModel(
325+
name: emptyModelName,
326+
downloadType: .latestModel,
327+
conditions: conditions,
328+
progressHandler: { progress in
329+
progressExpectation.fulfill()
330+
}
331+
) { result in
332+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
333+
334+
switch result {
335+
case .failure(.emptyModelName):
336+
// The expected error.
337+
break
338+
339+
default:
340+
XCTFail("Unexpected result: \(result)")
341+
}
342+
completionExpectation.fulfill()
343+
}
344+
345+
wait(for: [completionExpectation, progressExpectation], timeout: 5)
346+
}
347+
302348
/// Delete previously downloaded model.
303349
func testDeleteModel() {
304350
guard let testApp = FirebaseApp.app() else {
@@ -323,10 +369,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
323369
downloadType: downloadType,
324370
conditions: conditions,
325371
progressHandler: { progress in
372+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
326373
XCTAssertLessThanOrEqual(progress, 1)
327374
XCTAssertGreaterThanOrEqual(progress, 0)
328375
}
329376
) { result in
377+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
378+
330379
switch result {
331380
case let .success(model):
332381
XCTAssertNotNil(model.path)
@@ -343,6 +392,8 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
343392
let deleteExpectation = expectation(description: "Wait for model deletion.")
344393
modelDownloader.deleteDownloadedModel(name: testModelName) { result in
345394
deleteExpectation.fulfill()
395+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
396+
346397
switch result {
347398
case .success: break
348399
case let .failure(error):
@@ -494,10 +545,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
494545
downloadType: .latestModel,
495546
conditions: conditions,
496547
progressHandler: { progress in
548+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
497549
XCTAssertLessThanOrEqual(progress, 1)
498550
XCTAssertGreaterThanOrEqual(progress, 0)
499551
}
500552
) { result in
553+
XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
554+
501555
switch result {
502556
case let .success(model):
503557
XCTAssertNotNil(model.path)

0 commit comments

Comments
 (0)