Skip to content

Commit 42b7b6b

Browse files
committed
Update in background
1 parent e2535c9 commit 42b7b6b

File tree

4 files changed

+241
-44
lines changed

4 files changed

+241
-44
lines changed

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -120,44 +120,82 @@ public Task<CustomModel> getModel(
120120
if (localModel != null) {
121121
return Tasks.forResult(localModel);
122122
}
123-
Task<CustomModel> modelDetails =
124-
modelDownloadService.getCustomModelDetails(
125-
firebaseOptions.getProjectId(), modelName, null);
126-
127-
// no local model - start download.
128-
return modelDetails.continueWithTask(
129-
executor,
130-
modelDetailTask -> {
131-
if (modelDetailTask.isSuccessful()) {
132-
// start download
133-
return fileDownloadService
134-
.download(modelDetailTask.getResult(), conditions)
135-
.continueWithTask(
136-
executor,
137-
downloadTask -> {
138-
if (downloadTask.isSuccessful()) {
139-
// read the updated model
140-
CustomModel downloadedModel =
141-
sharedPreferencesUtil.getCustomModelDetails(modelName);
142-
// TODO(annz) trigger file move here as well... right now it's temp
143-
// call loadNewlyDownloadedModelFile
144-
return Tasks.forResult(downloadedModel);
145-
}
146-
return Tasks.forException(new Exception("File download failed."));
147-
});
148-
}
149-
return Tasks.forException(modelDetailTask.getException());
150-
});
123+
return getCustomModelTask(modelName, conditions);
151124
case LATEST_MODEL:
152-
// check for latest model and download newest
153-
break;
125+
// check for latest model, wait for latest if needed download newest
126+
return getCustomModelTask(modelName, conditions, localModel.getModelHash());
154127
case LOCAL_MODEL_UPDATE_IN_BACKGROUND:
155128
// start download in back ground return current model if not null.
156-
break;
129+
if (localModel != null) {
130+
// trigger update in background as needed - ignoring response.
131+
getCustomModelTask(modelName, conditions, localModel.getModelHash());
132+
// return local model
133+
return Tasks.forResult(localModel);
134+
}
135+
// no local model - get latest.
136+
return getCustomModelTask(modelName, conditions);
157137
}
158138
throw new UnsupportedOperationException("Not yet implemented.");
159139
}
160140

141+
// This version of getCustomModelTask will always call the modelDownloadService and upon
142+
// success will then trigger file download.
143+
private Task<CustomModel> getCustomModelTask(
144+
@NonNull String modelName, @Nullable CustomModelDownloadConditions conditions)
145+
throws Exception {
146+
return getCustomModelTask(modelName, conditions, null);
147+
}
148+
149+
// This version of getCustomModelTask will call the modelDownloadService and upon
150+
// success will only trigger file download, if there is a new model hash value.
151+
private Task<CustomModel> getCustomModelTask(
152+
@NonNull String modelName,
153+
@Nullable CustomModelDownloadConditions conditions,
154+
String modelHash)
155+
throws Exception {
156+
Task<CustomModel> incomingModelDetails =
157+
modelDownloadService.getCustomModelDetails(
158+
firebaseOptions.getProjectId(), modelName, modelHash);
159+
160+
return incomingModelDetails.continueWithTask(
161+
executor,
162+
incomingModelDetailTask -> {
163+
if (incomingModelDetailTask.isSuccessful()) {
164+
// null means we have the latest model
165+
CustomModel currentModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
166+
if (incomingModelDetails.getResult() == null) {
167+
return Tasks.forResult(currentModel);
168+
}
169+
170+
// if modelHash matches current local model just return local model.
171+
if (currentModel.getModelHash().equals(incomingModelDetails.getResult().getModelHash())) {
172+
if (!currentModel.getLocalFilePath().isEmpty()) {
173+
return Tasks.forResult(currentModel);
174+
}
175+
// todo(annzimmer) wait for download? continue?
176+
}
177+
178+
// start download
179+
return fileDownloadService
180+
.download(incomingModelDetailTask.getResult(), conditions)
181+
.continueWithTask(
182+
executor,
183+
downloadTask -> {
184+
if (downloadTask.isSuccessful()) {
185+
// read the updated model
186+
CustomModel downloadedModel =
187+
currentModel;
188+
// TODO(annz) trigger file move here as well... right now it's temp
189+
// call loadNewlyDownloadedModelFile
190+
return Tasks.forResult(downloadedModel);
191+
}
192+
return Tasks.forException(new Exception("File download failed."));
193+
});
194+
}
195+
return Tasks.forException(incomingModelDetailTask.getException());
196+
});
197+
}
198+
161199
/**
162200
* Triggers the move to permanent storage of successful model downloads and lists all models
163201
* downloaded to device.

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
/** Model File Manager is used to move the downloaded file to the appropriate locations. */
3333
public class ModelFileManager {
3434

35-
@VisibleForTesting
36-
static final String CUSTOM_MODEL_ROOT_PATH = "com.google.firebase.ml.custom.models";
35+
public static final String CUSTOM_MODEL_ROOT_PATH = "com.google.firebase.ml.custom.models";
3736

3837
private static final int INVALID_INDEX = -1;
3938
private final Context context;

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

Lines changed: 170 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
import static com.google.common.truth.Truth.assertThat;
1818
import static org.junit.Assert.assertEquals;
1919
import static org.junit.Assert.assertThrows;
20+
import static org.junit.Assert.assertTrue;
2021
import static org.mockito.ArgumentMatchers.any;
2122
import static org.mockito.ArgumentMatchers.eq;
2223
import static org.mockito.Mockito.doNothing;
2324
import static org.mockito.Mockito.times;
2425
import static org.mockito.Mockito.verify;
2526
import static org.mockito.Mockito.when;
2627

28+
import android.os.ParcelFileDescriptor;
2729
import androidx.test.core.app.ApplicationProvider;
2830
import com.google.android.gms.tasks.Task;
2931
import com.google.android.gms.tasks.Tasks;
@@ -32,11 +34,14 @@
3234
import com.google.firebase.FirebaseOptions.Builder;
3335
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
3436
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
37+
import com.google.firebase.ml.modeldownloader.internal.ModelFileManager;
3538
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
39+
import java.io.File;
3640
import java.util.Collections;
3741
import java.util.Set;
3842
import java.util.concurrent.ExecutorService;
3943
import java.util.concurrent.Executors;
44+
import org.junit.After;
4045
import org.junit.Before;
4146
import org.junit.Test;
4247
import org.junit.runner.RunWith;
@@ -54,29 +59,42 @@ public class FirebaseModelDownloaderTest {
5459
.setProjectId(TEST_PROJECT_ID)
5560
.build();
5661
public static final String MODEL_NAME = "MODEL_NAME_1";
62+
public static final String MODEL_URL = "https://project.firebase.com/modelName/23424.jpg";
63+
private static final long URL_EXPIRATION = 604800L;
64+
5765
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
5866
new CustomModelDownloadConditions.Builder().build();
5967

6068
public static final String MODEL_HASH = "dsf324";
69+
public static final String UPDATE_MODEL_HASH = "fgh564";
6170
public static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS =
6271
new CustomModelDownloadConditions.Builder().requireWifi().build();
6372

6473
// TODO replace with uploaded model.
6574
final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0);
75+
final CustomModel UPDATE_CUSTOM_MODEL_URL =
76+
new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L);
77+
CustomModel customModelUploaded;
6678

6779
FirebaseModelDownloader firebaseModelDownloader;
6880
@Mock SharedPreferencesUtil mockPrefs;
6981
@Mock ModelFileDownloadService mockFileDownloadService;
7082
@Mock CustomModelDownloadService mockModelDownloadService;
7183
ExecutorService executor;
7284

85+
private File testModelFile;
86+
private File updatetestModelFile;
87+
private File modelFile0;
88+
String expectedDestinationFolder;
89+
ModelFileManager fileManager;
90+
7391
@Before
74-
public void setUp() {
92+
public void setUp() throws Exception {
7593
MockitoAnnotations.initMocks(this);
7694
FirebaseApp.clearInstancesForTest();
7795
// default app
78-
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);
79-
96+
FirebaseApp app =
97+
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);
8098
executor = Executors.newSingleThreadExecutor();
8199
firebaseModelDownloader =
82100
new FirebaseModelDownloader(
@@ -85,22 +103,164 @@ public void setUp() {
85103
mockFileDownloadService,
86104
mockModelDownloadService,
87105
executor);
106+
setUpTestingFiles(app);
107+
}
108+
109+
private void setUpTestingFiles(FirebaseApp app) throws Exception {
110+
fileManager = new ModelFileManager(app);
111+
final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels");
112+
testDir.mkdirs();
113+
// make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since
114+
// we're only using this directory for this test and we won't create any subdirs.
115+
for (File f : testDir.listFiles()) {
116+
if (f.isFile()) {
117+
f.delete();
118+
}
119+
}
120+
121+
testModelFile = File.createTempFile("modelFile", ".tflite");
122+
updatetestModelFile = File.createTempFile("modelFileUpdated", ".tflite");
123+
124+
expectedDestinationFolder =
125+
new File(
126+
app.getApplicationContext().getNoBackupFilesDir(),
127+
ModelFileManager.CUSTOM_MODEL_ROOT_PATH)
128+
.getAbsolutePath()
129+
+ "/"
130+
+ app.getPersistenceKey()
131+
+ "/"
132+
+ MODEL_NAME;
133+
// move first test file to a model, keep second for "updates"
134+
ParcelFileDescriptor fd =
135+
ParcelFileDescriptor.open(testModelFile, ParcelFileDescriptor.MODE_READ_ONLY);
136+
137+
modelFile0 = fileManager.moveModelToDestinationFolder(CUSTOM_MODEL, fd);
138+
assertEquals(modelFile0, new File(expectedDestinationFolder + "/0"));
139+
assertTrue(modelFile0.exists());
140+
customModelUploaded =
141+
new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, expectedDestinationFolder + "/0");
88142
}
89143

144+
@After
145+
public void teardown() {
146+
testModelFile.deleteOnExit();
147+
updatetestModelFile.deleteOnExit();
148+
modelFile0.deleteOnExit();
149+
}
150+
151+
// TODO(annz) Add all the conditional unit tests to match!
90152
@Test
91153
public void getModel_unimplemented() {
92154
assertThrows(
93155
UnsupportedOperationException.class,
94156
() ->
95157
FirebaseModelDownloader.getInstance()
96-
.getModel(
97-
MODEL_NAME,
98-
DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
99-
DEFAULT_DOWNLOAD_CONDITIONS));
158+
.getModel(MODEL_NAME, DownloadType.LATEST_MODEL, DEFAULT_DOWNLOAD_CONDITIONS));
159+
}
160+
161+
@Test
162+
public void getModel_updateBackground_localExists_noUpdate() throws Exception {
163+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(CUSTOM_MODEL);
164+
when(mockModelDownloadService.getCustomModelDetails(
165+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
166+
.thenReturn(Tasks.forResult(null)); // no change found
167+
168+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
169+
Task<CustomModel> task =
170+
firebaseModelDownloader.getModel(
171+
MODEL_NAME, DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, DOWNLOAD_CONDITIONS);
172+
task.addOnCompleteListener(executor, onCompleteListener);
173+
CustomModel customModel = onCompleteListener.await();
174+
175+
verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME));
176+
assertThat(task.isComplete()).isTrue();
177+
assertEquals(customModel, CUSTOM_MODEL);
178+
}
179+
180+
@Test
181+
public void getModel_updateBackground_localExists_sameHash() throws Exception {
182+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(CUSTOM_MODEL);
183+
when(mockModelDownloadService.getCustomModelDetails(
184+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
185+
.thenReturn(Tasks.forResult(CUSTOM_MODEL)); // no change found
186+
187+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
188+
Task<CustomModel> task =
189+
firebaseModelDownloader.getModel(
190+
MODEL_NAME, DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, DOWNLOAD_CONDITIONS);
191+
task.addOnCompleteListener(executor, onCompleteListener);
192+
CustomModel customModel = onCompleteListener.await();
193+
194+
verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME));
195+
assertThat(task.isComplete()).isTrue();
196+
assertEquals(customModel, CUSTOM_MODEL);
197+
}
198+
199+
@Test
200+
public void getModel_updateBackground_localExists_UpdateFound() throws Exception {
201+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(CUSTOM_MODEL);
202+
when(mockModelDownloadService.getCustomModelDetails(
203+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
204+
.thenReturn(Tasks.forResult(UPDATE_CUSTOM_MODEL_URL));
205+
206+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
207+
Task<CustomModel> task =
208+
firebaseModelDownloader.getModel(
209+
MODEL_NAME, DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, DOWNLOAD_CONDITIONS);
210+
task.addOnCompleteListener(executor, onCompleteListener);
211+
CustomModel customModel = onCompleteListener.await();
212+
213+
verify(mockPrefs, times(1)).getCustomModelDetails(eq(MODEL_NAME));
214+
assertThat(task.isComplete()).isTrue();
215+
assertEquals(customModel, CUSTOM_MODEL);
216+
}
217+
218+
@Test
219+
public void getModel_updateBackground_noLocalModel() throws Exception {
220+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL);
221+
when(mockModelDownloadService.getCustomModelDetails(
222+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
223+
.thenReturn(Tasks.forResult(CUSTOM_MODEL));
224+
when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS)))
225+
.thenReturn(Tasks.forResult(null));
226+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
227+
Task<CustomModel> task =
228+
firebaseModelDownloader.getModel(
229+
MODEL_NAME, DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, DOWNLOAD_CONDITIONS);
230+
task.addOnCompleteListener(executor, onCompleteListener);
231+
CustomModel customModel = onCompleteListener.await();
232+
233+
verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME));
234+
assertThat(task.isComplete()).isTrue();
235+
assertEquals(customModel, CUSTOM_MODEL);
236+
}
237+
238+
@Test
239+
public void getModel_updateBackground_noLocalModel_error() throws Exception {
240+
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL);
241+
when(mockModelDownloadService.getCustomModelDetails(
242+
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
243+
.thenReturn(Tasks.forResult(CUSTOM_MODEL));
244+
when(mockFileDownloadService.download(any(), eq(DOWNLOAD_CONDITIONS)))
245+
.thenReturn(Tasks.forException(new Exception("bad download")));
246+
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
247+
Task<CustomModel> task =
248+
firebaseModelDownloader.getModel(
249+
MODEL_NAME, DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, DOWNLOAD_CONDITIONS);
250+
task.addOnCompleteListener(executor, onCompleteListener);
251+
try {
252+
onCompleteListener.await();
253+
} catch (Exception ex) {
254+
assertThat(ex.getMessage().contains("download failed")).isTrue();
255+
}
256+
257+
verify(mockPrefs, times(1)).getCustomModelDetails(eq(MODEL_NAME));
258+
assertThat(task.isComplete()).isTrue();
259+
assertThat(task.isSuccessful()).isFalse();
100260
}
101261

102262
@Test
103-
public void getModel_localExists() throws Exception {
263+
public void getModel_Local_localExists() throws Exception {
104264
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(CUSTOM_MODEL);
105265
TestOnCompleteListener<CustomModel> onCompleteListener = new TestOnCompleteListener<>();
106266
Task<CustomModel> task =
@@ -114,7 +274,7 @@ public void getModel_localExists() throws Exception {
114274
}
115275

116276
@Test
117-
public void getModel_noLocalModel() throws Exception {
277+
public void getModel_local_noLocalModel() throws Exception {
118278
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL);
119279
when(mockModelDownloadService.getCustomModelDetails(
120280
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))
@@ -133,7 +293,7 @@ public void getModel_noLocalModel() throws Exception {
133293
}
134294

135295
@Test
136-
public void getModel_noLocalModel_error() throws Exception {
296+
public void getModel_local_noLocalModel_error() throws Exception {
137297
when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(null).thenReturn(CUSTOM_MODEL);
138298
when(mockModelDownloadService.getCustomModelDetails(
139299
eq(TEST_PROJECT_ID), eq(MODEL_NAME), eq(null)))

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ private void setUpTestingFiles(FirebaseApp app) throws IOException {
7575
}
7676
}
7777

78-
testModelFile = File.createTempFile("modelFile", "tflite");
78+
testModelFile = File.createTempFile("modelFile", ".tflite");
7979
expectedDestinationFolder =
8080
new File(
8181
app.getApplicationContext().getNoBackupFilesDir(),

0 commit comments

Comments
 (0)