@@ -121,34 +121,114 @@ public static FirebaseModelDownloader getInstance(@NonNull FirebaseApp app) {
121
121
public Task <CustomModel > getModel (
122
122
@ NonNull String modelName ,
123
123
@ NonNull DownloadType downloadType ,
124
- @ Nullable CustomModelDownloadConditions conditions )
125
- throws Exception {
126
- CustomModel localModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
127
- if (localModel == null ) {
124
+ @ Nullable CustomModelDownloadConditions conditions ) {
125
+ CustomModel localModelDetails = getLocalModelDetails (modelName );
126
+ if (localModelDetails == null ) {
128
127
// no local model - get latest.
129
128
return getCustomModelTask (modelName , conditions );
130
129
}
131
130
132
131
switch (downloadType ) {
133
132
case LOCAL_MODEL :
134
- return Tasks . forResult ( localModel );
133
+ return getCompletedLocalCustomModelTask ( localModelDetails );
135
134
case LATEST_MODEL :
136
135
// check for latest model, wait for download if newer model exists
137
- return getCustomModelTask (modelName , conditions , localModel .getModelHash ());
136
+ return getCustomModelTask (modelName , conditions , localModelDetails .getModelHash ());
138
137
case LOCAL_MODEL_UPDATE_IN_BACKGROUND :
139
- // start download in back ground, return local model
140
- getCustomModelTask (modelName , conditions , localModel .getModelHash ());
141
- return Tasks . forResult ( localModel );
138
+ // start download in background, if newer model exists
139
+ getCustomModelTask (modelName , conditions , localModelDetails .getModelHash ());
140
+ return getCompletedLocalCustomModelTask ( localModelDetails );
142
141
}
143
- throw new IllegalArgumentException (
144
- "Unsupported downloadType, please chose LOCAL_MODEL, LATEST_MODEL, or LOCAL_MODEL_UPDATE_IN_BACKGROUND" );
142
+ return Tasks .forException (
143
+ new FirebaseMlException (
144
+ "Unsupported downloadType, please chose LOCAL_MODEL, LATEST_MODEL, or LOCAL_MODEL_UPDATE_IN_BACKGROUND" ,
145
+ FirebaseMlException .INVALID_ARGUMENT ));
146
+ }
147
+
148
+ /**
149
+ * Checks the local model, if a completed download exists - returns this model. Else if a download
150
+ * is in progress returns the downloading model version. Otherwise, this model is in a bad state -
151
+ * clears the model and return null
152
+ *
153
+ * @param modelName - name of the model
154
+ * @return the local model with file downloaded details or null if no local model.
155
+ */
156
+ @ Nullable
157
+ private CustomModel getLocalModelDetails (@ NonNull String modelName ) {
158
+ CustomModel localModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
159
+ if (localModel == null ) {
160
+ return null ;
161
+ }
162
+
163
+ // valid model file exists when local file path is set
164
+ if (localModel .getLocalFilePath () != null && localModel .isModelFilePresent ()) {
165
+ return localModel ;
166
+ }
167
+
168
+ // download is in progress - return downloading model details
169
+ if (localModel .getDownloadId () != 0 ) {
170
+ return sharedPreferencesUtil .getDownloadingCustomModelDetails (modelName );
171
+ }
172
+
173
+ // bad model state - delete all existing details and return null
174
+ deleteModelDetails (localModel .getName ());
175
+ return null ;
176
+ }
177
+
178
+ // Given a model, if the local file path is present, return model.
179
+ // Else if there is a file download is in progress, returns the download task.
180
+ // Otherwise reset model and return null - this should not happen.
181
+ private Task <CustomModel > getCompletedLocalCustomModelTask (@ NonNull CustomModel model ) {
182
+ // model file exists - use this
183
+ if (model .isModelFilePresent ()) {
184
+ return Tasks .forResult (model );
185
+ }
186
+
187
+ // download in progress - return the downloading task.
188
+ if (model .getDownloadId () != 0 ) {
189
+
190
+ // download in progress - find existing download task and wait for it to complete.
191
+ Task <Void > downloadInProgressTask =
192
+ fileDownloadService .getExistingDownloadTask (model .getDownloadId ());
193
+
194
+ if (downloadInProgressTask != null ) {
195
+ return downloadInProgressTask .continueWithTask (
196
+ executor ,
197
+ downloadTask -> {
198
+ if (downloadTask .isSuccessful ()) {
199
+ return finishModelDownload (model .getName ());
200
+ } else if (downloadTask .getException () instanceof FirebaseMlException ) {
201
+ return Tasks .forException ((FirebaseMlException ) downloadTask .getException ());
202
+ }
203
+ return Tasks .forException (
204
+ new FirebaseMlException (
205
+ "Model download failed for " + model .getName (),
206
+ FirebaseMlException .INTERNAL ));
207
+ });
208
+ }
209
+
210
+ // maybe download just completed - fetch latest model to check.
211
+ CustomModel latestModel = sharedPreferencesUtil .getCustomModelDetails (model .getName ());
212
+ if (latestModel != null && latestModel .isModelFilePresent ()) {
213
+ return Tasks .forResult (latestModel );
214
+ }
215
+ }
216
+
217
+ // bad model state - delete all existing model details and return exception
218
+ return deleteDownloadedModel (model .getName ())
219
+ .continueWithTask (
220
+ executor ,
221
+ deletionTask ->
222
+ Tasks .forException (
223
+ new FirebaseMlException (
224
+ "Model download in bad state - please retry" ,
225
+ FirebaseMlException .INTERNAL )));
145
226
}
146
227
147
228
// This version of getCustomModelTask will always call the modelDownloadService and upon
148
229
// success will then trigger file download.
149
230
private Task <CustomModel > getCustomModelTask (
150
- @ NonNull String modelName , @ Nullable CustomModelDownloadConditions conditions )
151
- throws Exception {
231
+ @ NonNull String modelName , @ Nullable CustomModelDownloadConditions conditions ) {
152
232
return getCustomModelTask (modelName , conditions , null );
153
233
}
154
234
@@ -157,8 +237,12 @@ private Task<CustomModel> getCustomModelTask(
157
237
private Task <CustomModel > getCustomModelTask (
158
238
@ NonNull String modelName ,
159
239
@ Nullable CustomModelDownloadConditions conditions ,
160
- @ Nullable String modelHash )
161
- throws Exception {
240
+ @ Nullable String modelHash ) {
241
+ CustomModel currentModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
242
+ if (currentModel == null && modelHash != null ) {
243
+ // todo(annzimmer) log something about mismatched state and use hash = null
244
+ modelHash = null ;
245
+ }
162
246
Task <CustomModel > incomingModelDetails =
163
247
modelDownloadService .getCustomModelDetails (
164
248
firebaseOptions .getProjectId (), modelName , modelHash );
@@ -167,10 +251,22 @@ private Task<CustomModel> getCustomModelTask(
167
251
executor ,
168
252
incomingModelDetailTask -> {
169
253
if (incomingModelDetailTask .isSuccessful ()) {
170
- CustomModel currentModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
171
- // null means we have the latest model
254
+ // null means we have the latest model or we failed to connect.
172
255
if (incomingModelDetailTask .getResult () == null ) {
173
- return Tasks .forResult (currentModel );
256
+ if (currentModel != null ) {
257
+ return getCompletedLocalCustomModelTask (currentModel );
258
+ }
259
+ // double check due to timing.
260
+ CustomModel updatedModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
261
+ if (updatedModel != null ) {
262
+ return getCompletedLocalCustomModelTask (updatedModel );
263
+ }
264
+ // clean up model internally
265
+ deleteModelDetails (currentModel .getName ());
266
+ return Tasks .forException (
267
+ new FirebaseMlException (
268
+ "Possible caching issues: no model associated with " + modelName + "." ,
269
+ FirebaseMlException .INTERNAL ));
174
270
}
175
271
176
272
// if modelHash matches current local model just return local model.
@@ -183,7 +279,7 @@ private Task<CustomModel> getCustomModelTask(
183
279
&& currentModel .getLocalFilePath () != null
184
280
&& !currentModel .getLocalFilePath ().isEmpty ()
185
281
&& new File (currentModel .getLocalFilePath ()).exists ()) {
186
- return Tasks . forResult (currentModel );
282
+ return getCompletedLocalCustomModelTask (currentModel );
187
283
}
188
284
189
285
// is download already in progress for this hash?
@@ -193,12 +289,14 @@ && new File(currentModel.getLocalFilePath()).exists()) {
193
289
if (downloadingModel != null
194
290
&& downloadingModel
195
291
.getModelHash ()
196
- .equals (incomingModelDetails .getResult ().getModelHash ()))
292
+ .equals (incomingModelDetails .getResult ().getModelHash ())) {
197
293
return Tasks .forResult (downloadingModel );
294
+ }
295
+ // todo(annzimmer) this shouldn't happen unless they are calling the sdk with
296
+ // multiple
297
+ // sets of download types/conditions.
298
+ // this should be a download in progress - add appropriate handling.
198
299
}
199
- // todo(annzimmer) this shouldn't happen unless they are calling the sdk with multiple
200
- // sets of download types/conditions.
201
- // this should be a download in progress - add appropriate handling.
202
300
}
203
301
204
302
// start download
@@ -208,24 +306,7 @@ && new File(currentModel.getLocalFilePath()).exists()) {
208
306
executor ,
209
307
downloadTask -> {
210
308
if (downloadTask .isSuccessful ()) {
211
- // read the updated model
212
- CustomModel updatedModel =
213
- sharedPreferencesUtil .getDownloadingCustomModelDetails (modelName );
214
- if (updatedModel == null ) {
215
- // either download failed or it completed really fast.
216
- return Tasks .forResult (
217
- sharedPreferencesUtil .getCustomModelDetails (modelName ));
218
- }
219
- // trigger the file to be moved to permanent location
220
- // This handles immediate download and completion.
221
- fileDownloadService .loadNewlyDownloadedModelFile (updatedModel );
222
- updatedModel =
223
- sharedPreferencesUtil .getDownloadingCustomModelDetails (modelName );
224
- // download complete - get current model.
225
- if (updatedModel == null ) {
226
- updatedModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
227
- }
228
- return Tasks .forResult (updatedModel );
309
+ return finishModelDownload (modelName );
229
310
} else {
230
311
return retryExpiredUrlDownload (modelName , conditions , downloadTask , 2 );
231
312
}
@@ -239,13 +320,12 @@ private Task<CustomModel> retryExpiredUrlDownload(
239
320
@ NonNull String modelName ,
240
321
@ Nullable CustomModelDownloadConditions conditions ,
241
322
Task <Void > downloadTask ,
242
- int retryCounter )
243
- throws Exception {
323
+ int retryCounter ) {
244
324
if (downloadTask .getException ().getMessage ().contains ("Retry: Expired URL" )) {
245
- // this is likely an expired url - retry once .
325
+ // this is likely an expired url - retry.
246
326
Task <CustomModel > retryModelDetails =
247
- modelDownloadService .getCustomModelDetails (
248
- firebaseOptions .getProjectId (), modelName , null );
327
+ modelDownloadService .getNewDownloadUrlWithExpiry (
328
+ firebaseOptions .getProjectId (), modelName );
249
329
// no local model - start download.
250
330
return retryModelDetails .continueWithTask (
251
331
executor ,
@@ -258,13 +338,7 @@ private Task<CustomModel> retryExpiredUrlDownload(
258
338
executor ,
259
339
retryDownloadTask -> {
260
340
if (retryDownloadTask .isSuccessful ()) {
261
- // read the updated model
262
- CustomModel downloadedModel =
263
- sharedPreferencesUtil .getCustomModelDetails (modelName );
264
- // TODO(annz) trigger file move here as well... right
265
- // now it's temp
266
- // call loadNewlyDownloadedModelFile
267
- return Tasks .forResult (downloadedModel );
341
+ return finishModelDownload (modelName );
268
342
}
269
343
if (retryCounter > 1 ) {
270
344
return retryExpiredUrlDownload (
@@ -280,6 +354,24 @@ private Task<CustomModel> retryExpiredUrlDownload(
280
354
return Tasks .forException (new Exception ("File download failed." ));
281
355
}
282
356
357
+ private Task <CustomModel > finishModelDownload (@ NonNull String modelName ) {
358
+ // read the updated model
359
+ CustomModel downloadedModel = sharedPreferencesUtil .getDownloadingCustomModelDetails (modelName );
360
+ if (downloadedModel == null ) {
361
+ // check if latest download completed - if so use current.
362
+ downloadedModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
363
+ if (downloadedModel == null ) {
364
+ return Tasks .forException (
365
+ new Exception (
366
+ "Model (" + modelName + ") expected and not found during download completion." ));
367
+ }
368
+ }
369
+ // trigger the file to be moved to permanent location.
370
+ fileDownloadService .loadNewlyDownloadedModelFile (downloadedModel );
371
+ downloadedModel = sharedPreferencesUtil .getCustomModelDetails (modelName );
372
+ return Tasks .forResult (downloadedModel );
373
+ }
374
+
283
375
/**
284
376
* Triggers the move to permanent storage of successful model downloads and lists all models
285
377
* downloaded to device.
@@ -290,11 +382,7 @@ private Task<CustomModel> retryExpiredUrlDownload(
290
382
@ NonNull
291
383
public Task <Set <CustomModel >> listDownloadedModels () {
292
384
// trigger completion of file moves for download files.
293
- try {
294
- fileDownloadService .maybeCheckDownloadingComplete ();
295
- } catch (Exception ex ) {
296
- System .out .println ("Error checking for in progress downloads: " + ex .getMessage ());
297
- }
385
+ fileDownloadService .maybeCheckDownloadingComplete ();
298
386
299
387
TaskCompletionSource <Set <CustomModel >> taskCompletionSource = new TaskCompletionSource <>();
300
388
executor .execute (
@@ -314,13 +402,17 @@ public Task<Void> deleteDownloadedModel(@NonNull String modelName) {
314
402
executor .execute (
315
403
() -> {
316
404
// remove all files associated with this model and then clean up model references.
317
- fileManager .deleteAllModels (modelName );
318
- sharedPreferencesUtil .clearModelDetails (modelName );
405
+ deleteModelDetails (modelName );
319
406
taskCompletionSource .setResult (null );
320
407
});
321
408
return taskCompletionSource .getTask ();
322
409
}
323
410
411
+ private void deleteModelDetails (@ NonNull String modelName ) {
412
+ fileManager .deleteAllModels (modelName );
413
+ sharedPreferencesUtil .clearModelDetails (modelName );
414
+ }
415
+
324
416
/**
325
417
* Update the settings which allow logging to firelog.
326
418
*
0 commit comments