@@ -56,10 +56,10 @@ public class ModelDownloader {
56
56
/// DispatchQueue to manage download task dictionary.
57
57
let taskSerialQueue = DispatchQueue ( label: " downloadtask.serial.queue " )
58
58
59
- /// Handler that always runs on the main thread
60
- let mainQueueHandler = { handler in
59
+ /// Re-dispatch a function on the main queue.
60
+ func asyncOnMainQueue ( _ work : @autoclosure @escaping ( ) -> Void ) {
61
61
DispatchQueue . main. async {
62
- handler
62
+ work ( )
63
63
}
64
64
}
65
65
@@ -132,13 +132,18 @@ public class ModelDownloader {
132
132
conditions: ModelDownloadConditions ,
133
133
progressHandler: ( ( Float ) -> Void ) ? = nil ,
134
134
completion: @escaping ( Result < CustomModel , DownloadError > ) -> Void ) {
135
+ guard !modelName. isEmpty else {
136
+ asyncOnMainQueue ( completion ( . failure( . emptyModelName) ) )
137
+ return
138
+ }
139
+
135
140
switch downloadType {
136
141
case . localModel:
137
142
if let localModel = getLocalModel ( modelName: modelName) {
138
143
DeviceLogger . logEvent ( level: . debug,
139
144
message: ModelDownloader . DebugDescription. localModelFound,
140
145
messageCode: . localModelFound)
141
- mainQueueHandler ( completion ( . success( localModel) ) )
146
+ asyncOnMainQueue ( completion ( . success( localModel) ) )
142
147
} else {
143
148
getRemoteModel (
144
149
modelName: modelName,
@@ -153,7 +158,7 @@ public class ModelDownloader {
153
158
DeviceLogger . logEvent ( level: . debug,
154
159
message: ModelDownloader . DebugDescription. localModelFound,
155
160
messageCode: . localModelFound)
156
- mainQueueHandler ( completion ( . success( localModel) ) )
161
+ asyncOnMainQueue ( completion ( . success( localModel) ) )
157
162
telemetryLogger? . logModelDownloadEvent (
158
163
eventName: . modelDownload,
159
164
status: . scheduled,
@@ -225,7 +230,7 @@ public class ModelDownloader {
225
230
DeviceLogger . logEvent ( level: . debug,
226
231
message: description,
227
232
messageCode: . modelNameParseError)
228
- mainQueueHandler ( completion ( . failure( . internalError( description: description) ) ) )
233
+ asyncOnMainQueue ( completion ( . failure( . internalError( description: description) ) ) )
229
234
return
230
235
}
231
236
// Check if model information corresponding to model is stored in UserDefaults.
@@ -234,19 +239,19 @@ public class ModelDownloader {
234
239
DeviceLogger . logEvent ( level: . debug,
235
240
message: description,
236
241
messageCode: . noLocalModelInfo)
237
- mainQueueHandler ( completion ( . failure( . internalError( description: description) ) ) )
242
+ asyncOnMainQueue ( completion ( . failure( . internalError( description: description) ) ) )
238
243
return
239
244
}
240
245
// Ensure that local model path is as expected, and reachable.
241
246
guard let modelURL = ModelFileManager . getDownloadedModelFileURL (
242
247
appName: appName,
243
248
modelName: modelName
244
249
) ,
245
- url == modelURL , ModelFileManager . isFileReachable ( at: modelURL) else {
250
+ ModelFileManager . isFileReachable ( at: modelURL) else {
246
251
DeviceLogger . logEvent ( level: . debug,
247
252
message: ModelDownloader . ErrorDescription. outdatedModelPath,
248
253
messageCode: . outdatedModelPathError)
249
- mainQueueHandler ( completion ( . failure( . internalError( description: ModelDownloader
254
+ asyncOnMainQueue ( completion ( . failure( . internalError( description: ModelDownloader
250
255
. ErrorDescription. outdatedModelPath) ) ) )
251
256
return
252
257
}
@@ -262,12 +267,12 @@ public class ModelDownloader {
262
267
DeviceLogger . logEvent ( level: . debug,
263
268
message: ModelDownloader . ErrorDescription. listModelsFailed ( error) ,
264
269
messageCode: . listModelsError)
265
- mainQueueHandler ( completion ( . failure( error) ) )
270
+ asyncOnMainQueue ( completion ( . failure( error) ) )
266
271
} catch {
267
272
DeviceLogger . logEvent ( level: . debug,
268
273
message: ModelDownloader . ErrorDescription. listModelsFailed ( error) ,
269
274
messageCode: . listModelsError)
270
- mainQueueHandler ( completion ( . failure( . internalError( description: error
275
+ asyncOnMainQueue ( completion ( . failure( . internalError( description: error
271
276
. localizedDescription) ) ) )
272
277
}
273
278
}
@@ -290,7 +295,7 @@ public class ModelDownloader {
290
295
DeviceLogger . logEvent ( level: . debug,
291
296
message: ModelDownloader . ErrorDescription. modelNotFound ( modelName) ,
292
297
messageCode: . modelNotFound)
293
- mainQueueHandler ( completion ( . failure( . notFound) ) )
298
+ asyncOnMainQueue ( completion ( . failure( . notFound) ) )
294
299
return
295
300
}
296
301
do {
@@ -305,7 +310,7 @@ public class ModelDownloader {
305
310
eventName: . remoteModelDeleteOnDevice,
306
311
isSuccessful: true
307
312
)
308
- mainQueueHandler ( completion ( . success( ( ) ) ) )
313
+ asyncOnMainQueue ( completion ( . success( ( ) ) ) )
309
314
} catch let error as DownloadedModelError {
310
315
DeviceLogger . logEvent ( level: . debug,
311
316
message: ModelDownloader . ErrorDescription. modelDeletionFailed ( error) ,
@@ -314,7 +319,7 @@ public class ModelDownloader {
314
319
eventName: . remoteModelDeleteOnDevice,
315
320
isSuccessful: false
316
321
)
317
- mainQueueHandler ( completion ( . failure( error) ) )
322
+ asyncOnMainQueue ( completion ( . failure( error) ) )
318
323
} catch {
319
324
DeviceLogger . logEvent ( level: . debug,
320
325
message: ModelDownloader . ErrorDescription. modelDeletionFailed ( error) ,
@@ -323,7 +328,7 @@ public class ModelDownloader {
323
328
eventName: . remoteModelDeleteOnDevice,
324
329
isSuccessful: false
325
330
)
326
- mainQueueHandler ( completion ( . failure( . internalError( description: error
331
+ asyncOnMainQueue ( completion ( . failure( . internalError( description: error
327
332
. localizedDescription) ) ) )
328
333
}
329
334
}
@@ -417,28 +422,28 @@ extension ModelDownloader {
417
422
// Progress handler for model file download.
418
423
let taskProgressHandler : ModelDownloadTask . ProgressHandler = { progress in
419
424
if let progressHandler = progressHandler {
420
- self . mainQueueHandler ( progressHandler ( progress) )
425
+ self . asyncOnMainQueue ( progressHandler ( progress) )
421
426
}
422
427
}
423
428
// Completion handler for model file download.
424
429
let taskCompletion : ModelDownloadTask . Completion = { result in
425
430
switch result {
426
431
case let . success( model) :
427
- self . mainQueueHandler ( completion ( . success( model) ) )
432
+ self . asyncOnMainQueue ( completion ( . success( model) ) )
428
433
case let . failure( error) :
429
434
switch error {
430
435
case . notFound:
431
- self . mainQueueHandler ( completion ( . failure( . notFound) ) )
436
+ self . asyncOnMainQueue ( completion ( . failure( . notFound) ) )
432
437
case . invalidArgument:
433
- self . mainQueueHandler ( completion ( . failure( . invalidArgument) ) )
438
+ self . asyncOnMainQueue ( completion ( . failure( . invalidArgument) ) )
434
439
case . permissionDenied:
435
- self . mainQueueHandler ( completion ( . failure( . permissionDenied) ) )
440
+ self . asyncOnMainQueue ( completion ( . failure( . permissionDenied) ) )
436
441
// This is the error returned when model download URL has expired.
437
442
case . expiredDownloadURL:
438
443
// Retry model info and model file download, if allowed.
439
444
guard self . numberOfRetries > 0 else {
440
445
self
441
- . mainQueueHandler (
446
+ . asyncOnMainQueue (
442
447
completion ( . failure( . internalError( description: ModelDownloader
443
448
. ErrorDescription
444
449
. expiredModelInfo) ) )
@@ -458,7 +463,7 @@ extension ModelDownloader {
458
463
completion: completion
459
464
)
460
465
default :
461
- self . mainQueueHandler ( completion ( . failure( error) ) )
466
+ self . asyncOnMainQueue ( completion ( . failure( error) ) )
462
467
}
463
468
}
464
469
self . taskSerialQueue. async {
@@ -502,15 +507,15 @@ extension ModelDownloader {
502
507
guard let localModel = self . getLocalModel ( modelName: modelName) else {
503
508
// This can only happen if either local model info or the model file was wiped out after model info request but before server response.
504
509
self
505
- . mainQueueHandler ( completion ( . failure( . internalError( description: ModelDownloader
510
+ . asyncOnMainQueue ( completion ( . failure( . internalError( description: ModelDownloader
506
511
. ErrorDescription. deletedLocalModelInfoOrFile) ) ) )
507
512
return
508
513
}
509
- self . mainQueueHandler ( completion ( . success( localModel) ) )
514
+ self . asyncOnMainQueue ( completion ( . success( localModel) ) )
510
515
}
511
516
// Error retrieving model info.
512
517
case let . failure( error) :
513
- self . mainQueueHandler ( completion ( . failure( error) ) )
518
+ self . asyncOnMainQueue ( completion ( . failure( error) ) )
514
519
}
515
520
}
516
521
}
@@ -530,6 +535,8 @@ public enum DownloadError: Error, Equatable {
530
535
case notEnoughSpace
531
536
/// Malformed model name or Firebase app options.
532
537
case invalidArgument
538
+ /// Model name is empty.
539
+ case emptyModelName
533
540
/// Other errors with description.
534
541
case internalError( description: String )
535
542
}
0 commit comments