Skip to content

Commit c858d4c

Browse files
annzimmerschmidt-sebastian
authored andcommitted
Handle expired url file paths with retries. (#2276)
1 parent 8559a7d commit c858d4c

File tree

4 files changed

+251
-23
lines changed

4 files changed

+251
-23
lines changed

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,27 @@ private Task<CustomModel> getCustomModelTask(
172172

173173
// if modelHash matches current local model just return local model.
174174
// Should be handled by above case but just in case.
175-
if (currentModel != null
176-
&& currentModel
177-
.getModelHash()
178-
.equals(incomingModelDetails.getResult().getModelHash())) {
179-
if (!currentModel.getLocalFilePath().isEmpty()
175+
if (currentModel != null) {
176+
// is this the same model?
177+
if (currentModel
178+
.getModelHash()
179+
.equals(incomingModelDetails.getResult().getModelHash())
180+
&& currentModel.getLocalFilePath() != null
181+
&& !currentModel.getLocalFilePath().isEmpty()
180182
&& new File(currentModel.getLocalFilePath()).exists()) {
181183
return Tasks.forResult(currentModel);
182184
}
185+
186+
// is download already in progress for this hash?
187+
if (currentModel.getDownloadId() != 0) {
188+
CustomModel downloadingModel =
189+
sharedPreferencesUtil.getDownloadingCustomModelDetails(modelName);
190+
if (downloadingModel != null
191+
&& downloadingModel
192+
.getModelHash()
193+
.equals(incomingModelDetails.getResult().getModelHash()))
194+
return Tasks.forResult(downloadingModel);
195+
}
183196
// todo(annzimmer) this shouldn't happen unless they are calling the sdk with multiple
184197
// sets of download types/conditions.
185198
// this should be a download in progress - add appropriate handling.
@@ -193,19 +206,77 @@ && new File(currentModel.getLocalFilePath()).exists()) {
193206
downloadTask -> {
194207
if (downloadTask.isSuccessful()) {
195208
// read the updated model
196-
CustomModel downloadedModel =
197-
sharedPreferencesUtil.getCustomModelDetails(modelName);
198-
// trigger the file to be moved to permanent location.
199-
fileDownloadService.loadNewlyDownloadedModelFile(downloadedModel);
200-
return Tasks.forResult(downloadedModel);
209+
CustomModel updatedModel =
210+
sharedPreferencesUtil.getDownloadingCustomModelDetails(modelName);
211+
if (updatedModel == null) {
212+
// either download failed or it completed really fast.
213+
return Tasks.forResult(
214+
sharedPreferencesUtil.getCustomModelDetails(modelName));
215+
}
216+
// trigger the file to be moved to permanent location
217+
// This handles immediate download and completion.
218+
fileDownloadService.loadNewlyDownloadedModelFile(updatedModel);
219+
updatedModel =
220+
sharedPreferencesUtil.getDownloadingCustomModelDetails(modelName);
221+
// download complete - get current model.
222+
if (updatedModel == null) {
223+
updatedModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
224+
}
225+
return Tasks.forResult(updatedModel);
226+
} else {
227+
return retryExpiredUrlDownload(modelName, conditions, downloadTask, 2);
201228
}
202-
return Tasks.forException(new Exception("File download failed."));
203229
});
204230
}
205231
return Tasks.forException(incomingModelDetailTask.getException());
206232
});
207233
}
208234

235+
private Task<CustomModel> retryExpiredUrlDownload(
236+
@NonNull String modelName,
237+
@Nullable CustomModelDownloadConditions conditions,
238+
Task<Void> downloadTask,
239+
int retryCounter)
240+
throws Exception {
241+
if (downloadTask.getException().getMessage().contains("Retry: Expired URL")) {
242+
// this is likely an expired url - retry once.
243+
Task<CustomModel> retryModelDetails =
244+
modelDownloadService.getCustomModelDetails(
245+
firebaseOptions.getProjectId(), modelName, null);
246+
// no local model - start download.
247+
return retryModelDetails.continueWithTask(
248+
executor,
249+
retryModelDetailTask -> {
250+
if (retryModelDetailTask.isSuccessful()) {
251+
// start download
252+
return fileDownloadService
253+
.download(retryModelDetailTask.getResult(), conditions)
254+
.continueWithTask(
255+
executor,
256+
retryDownloadTask -> {
257+
if (retryDownloadTask.isSuccessful()) {
258+
// read the updated model
259+
CustomModel downloadedModel =
260+
sharedPreferencesUtil.getCustomModelDetails(modelName);
261+
// TODO(annz) trigger file move here as well... right
262+
// now it's temp
263+
// call loadNewlyDownloadedModelFile
264+
return Tasks.forResult(downloadedModel);
265+
}
266+
if (retryCounter > 1) {
267+
return retryExpiredUrlDownload(
268+
modelName, conditions, downloadTask, retryCounter - 1);
269+
}
270+
return Tasks.forException(
271+
new Exception("File download failed. Too many attempts."));
272+
});
273+
}
274+
return Tasks.forException(retryModelDetailTask.getException());
275+
});
276+
}
277+
return Tasks.forException(new Exception("File download failed."));
278+
}
279+
209280
/**
210281
* Triggers the move to permanent storage of successful model downloads and lists all models
211282
* downloaded to device.

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions;
4141
import java.io.File;
4242
import java.io.FileNotFoundException;
43+
import java.util.Date;
4344
import java.util.regex.Matcher;
4445
import java.util.regex.Pattern;
4546

@@ -120,21 +121,23 @@ Task<Void> ensureModelDownloaded(CustomModel customModel) {
120121
return Tasks.forException(new Exception("Failed to schedule the download task"));
121122
}
122123

123-
return registerReceiverForDownloadId(newDownloadId);
124+
return registerReceiverForDownloadId(newDownloadId, customModel.getName());
124125
}
125126

126-
private synchronized DownloadBroadcastReceiver getReceiverInstance(long downloadId) {
127+
private synchronized DownloadBroadcastReceiver getReceiverInstance(
128+
long downloadId, String modelName) {
127129
DownloadBroadcastReceiver receiver = receiverMaps.get(downloadId);
128130
if (receiver == null) {
129131
receiver =
130-
new DownloadBroadcastReceiver(downloadId, getTaskCompletionSourceInstance(downloadId));
132+
new DownloadBroadcastReceiver(
133+
downloadId, modelName, getTaskCompletionSourceInstance(downloadId));
131134
receiverMaps.put(downloadId, receiver);
132135
}
133136
return receiver;
134137
}
135138

136-
private Task<Void> registerReceiverForDownloadId(long downloadId) {
137-
BroadcastReceiver broadcastReceiver = getReceiverInstance(downloadId);
139+
private Task<Void> registerReceiverForDownloadId(long downloadId, String modelName) {
140+
BroadcastReceiver broadcastReceiver = getReceiverInstance(downloadId, modelName);
138141
// It is okay to always register here. Since the broadcast receiver is the same via the lookup
139142
// for the same download id, the same broadcast receiver will be notified only once.
140143
context.registerReceiver(
@@ -249,10 +252,13 @@ public void maybeCheckDownloadingComplete() throws Exception {
249252
if (matcher.find()) {
250253
String modelName = matcher.group(matcher.groupCount());
251254
CustomModel downloadingModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
252-
Integer statusCode = getDownloadingModelStatusCode(downloadingModel.getDownloadId());
253-
if (statusCode == DownloadManager.STATUS_SUCCESSFUL
254-
|| statusCode == DownloadManager.STATUS_FAILED) {
255-
loadNewlyDownloadedModelFile(downloadingModel);
255+
if (downloadingModel != null) {
256+
Integer statusCode = getDownloadingModelStatusCode(downloadingModel.getDownloadId());
257+
if (statusCode != null
258+
&& (statusCode == DownloadManager.STATUS_SUCCESSFUL
259+
|| statusCode == DownloadManager.STATUS_FAILED)) {
260+
loadNewlyDownloadedModelFile(downloadingModel);
261+
}
256262
}
257263
}
258264
}
@@ -264,7 +270,7 @@ public File loadNewlyDownloadedModelFile(CustomModel model) throws Exception {
264270
Long downloadingId = model.getDownloadId();
265271
String downloadingModelHash = model.getModelHash();
266272

267-
if (downloadingId == null || downloadingModelHash == null) {
273+
if (downloadingId == 0 || downloadingModelHash.isEmpty()) {
268274
// no downloading model file or incomplete info.
269275
return null;
270276
}
@@ -318,11 +324,13 @@ private class DownloadBroadcastReceiver extends BroadcastReceiver {
318324
// Download Id is captured inside this class in memory. So there is no concern of inconsistency
319325
// with the persisted download id in shared preferences.
320326
private final long downloadId;
327+
private final String modelName;
321328
private final TaskCompletionSource<Void> taskCompletionSource;
322329

323330
private DownloadBroadcastReceiver(
324-
long downloadId, TaskCompletionSource<Void> taskCompletionSource) {
331+
long downloadId, String modelName, TaskCompletionSource<Void> taskCompletionSource) {
325332
this.downloadId = downloadId;
333+
this.modelName = modelName;
326334
this.taskCompletionSource = taskCompletionSource;
327335
}
328336

@@ -351,6 +359,12 @@ public void onReceive(Context context, Intent intent) {
351359

352360
if (statusCode != null) {
353361
if (statusCode == DownloadManager.STATUS_FAILED) {
362+
if (checkErrorCausedByExpiry(id, modelName)) {
363+
// retry as a new download
364+
// todo change to FirebaseMlException retry error.
365+
taskCompletionSource.setException(new Exception("Retry: Expired URL"));
366+
return;
367+
}
354368
// todo add failure reason and logging
355369
System.out.println("Download Failed for id: " + id);
356370
taskCompletionSource.setException(new Exception("Failed"));
@@ -367,5 +381,28 @@ public void onReceive(Context context, Intent intent) {
367381
// Status code is null or not one of success or fail.
368382
taskCompletionSource.setException(new Exception("Model downloading failed"));
369383
}
384+
385+
private boolean checkErrorCausedByExpiry(Long downloadId, String modelName) {
386+
CustomModel model = sharedPreferencesUtil.getCustomModelDetails(modelName);
387+
388+
if (model == null) {
389+
return false;
390+
}
391+
392+
final Date time = new Date();
393+
394+
if (model.getDownloadUrlExpiry() < time.getTime()) {
395+
Cursor cursor =
396+
(downloadManager == null || downloadId == null)
397+
? null
398+
: downloadManager.query(new Query().setFilterById(downloadId));
399+
if (cursor != null && cursor.moveToFirst()) {
400+
int reason = cursor.getInt(cursor.getColumnIndex(DownloadManager.COLUMN_REASON));
401+
// 400 implies possibility of url expiry
402+
return (reason == 400);
403+
}
404+
}
405+
return false;
406+
}
370407
}
371408
}

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public class FirebaseModelDownloaderTest {
6161
private static final String MODEL_NAME = "MODEL_NAME_1";
6262
private static final String MODEL_URL = "https://project.firebase.com/modelName/23424.jpg";
6363
private static final long URL_EXPIRATION = 604800L;
64+
private static final long DOWNLOAD_ID = 99;
6465

6566
private static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
6667
new CustomModelDownloadConditions.Builder().build();
@@ -73,6 +74,8 @@ public class FirebaseModelDownloaderTest {
7374
private final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0);
7475
private final CustomModel UPDATE_CUSTOM_MODEL_URL =
7576
new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L);
77+
private final CustomModel UPDATE_CUSTOM_MODEL_DOWNLOADING =
78+
new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, DOWNLOAD_ID);
7679
private CustomModel customModelUploaded;
7780
private CustomModel customModelLoaded;
7881

@@ -262,6 +265,31 @@ public void getModel_latestModel_localExists_UpdateFound() throws Exception {
262265
assertEquals(customModel, customModelUploaded);
263266
}
264267

268+
@Test
269+
public void getModel_latestModel_localExists_DownloadInProgress() throws Exception {
270+
CustomModel customModelLoadedWithDownload =
271+
new CustomModel(MODEL_NAME, MODEL_HASH, 100, 99, expectedDestinationFolder + "/0");
272+
273+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(customModelLoadedWithDownload);
274+
when(mockPrefs.getDownloadingCustomModelDetails(eq(MODEL_NAME)))
275+
.thenReturn(UPDATE_CUSTOM_MODEL_DOWNLOADING);
276+
277+
when(mockModelDownloadService.getCustomModelDetails(
278+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(MODEL_HASH)))
279+
.thenReturn(Tasks.forResult(UPDATE_CUSTOM_MODEL_URL));
280+
281+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
282+
Task<CustomModel> task =
283+
firebaseModelDownloader.getModel(
284+
MODEL_NAME, DownloadType.LATEST_MODEL, DEFAULT_DOWNLOAD_CONDITIONS);
285+
task.addOnCompleteListener(executor, onCompleteListener);
286+
CustomModel customModel = onCompleteListener.await();
287+
288+
verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME));
289+
assertThat(task.isComplete()).isTrue();
290+
assertEquals(customModel, UPDATE_CUSTOM_MODEL_DOWNLOADING);
291+
}
292+
265293
@Test
266294
public void getModel_latestModel_noLocalModel() throws Exception {
267295
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME)))
@@ -459,6 +487,58 @@ public void getModel_local_noLocalModel() throws Exception {
459487
assertEquals(customModel, CUSTOM_MODEL);
460488
}
461489

490+
@Test
491+
public void getModel_local_noLocalModel_urlRetry() throws Exception {
492+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL);
493+
when(mockModelDownloadService.getCustomModelDetails(
494+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
495+
.thenReturn(Tasks.forResult(CUSTOM_MODEL));
496+
when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS)))
497+
.thenReturn(Tasks.forException(new Exception("Retry: Expired URL")))
498+
.thenReturn(Tasks.forResult(null));
499+
when(mockFileDownloadService.loadNewlyDownloadedModelFile(eq(customModelUploaded)))
500+
.thenReturn(firstDeviceModelFile);
501+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
502+
Task<CustomModel> task =
503+
firebaseModelDownloader.getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DOWNLOAD_CONDITIONS);
504+
task.addOnCompleteListener(executor, onCompleteListener);
505+
CustomModel customModel = onCompleteListener.await();
506+
507+
verify(mockPrefs, times(3)).getCustomModelDetails(eq(MODEL_NAME));
508+
verify(mockFileDownloadService, times(2)).download(any(), eq(DOWNLOAD_CONDITIONS));
509+
assertThat(task.isComplete()).isTrue();
510+
assertEquals(customModel, CUSTOM_MODEL);
511+
}
512+
513+
@Test
514+
public void getModel_local_noLocalModel_urlRetry_maxTries() throws Exception {
515+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME)))
516+
.thenReturn(null)
517+
.thenReturn(null)
518+
.thenReturn(CUSTOM_MODEL);
519+
when(mockModelDownloadService.getCustomModelDetails(
520+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
521+
.thenReturn(Tasks.forResult(CUSTOM_MODEL));
522+
when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS)))
523+
.thenReturn(Tasks.forException(new Exception("Retry: Expired URL")));
524+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
525+
Task<CustomModel> task =
526+
firebaseModelDownloader.getModel(MODEL_NAME, DownloadType.LOCAL_MODEL, DOWNLOAD_CONDITIONS);
527+
task.addOnCompleteListener(executor, onCompleteListener);
528+
try {
529+
onCompleteListener.await();
530+
} catch (Exception ex) {
531+
assertThat(ex.getMessage().contains("download failed")).isTrue();
532+
assertThat(ex.getMessage().contains("Too many attempts")).isTrue();
533+
}
534+
535+
verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME));
536+
verify(mockFileDownloadService, times(3)).download(any(), eq(DOWNLOAD_CONDITIONS));
537+
verify(mockFileDownloadService, never()).loadNewlyDownloadedModelFile(any());
538+
assertThat(task.isComplete()).isTrue();
539+
assertThat(task.isSuccessful()).isFalse();
540+
}
541+
462542
@Test
463543
public void getModel_local_noLocalModel_error() throws Exception {
464544
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME)))

0 commit comments

Comments
 (0)