Skip to content

Commit deb18fc

Browse files
committed
Add implementation for listModels.
1 parent 9c2cc45 commit deb18fc

File tree

6 files changed

+116
-16
lines changed

6 files changed

+116
-16
lines changed

firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@ dependencies {
5555
testImplementation 'androidx.test:core:1.3.0'
5656
testImplementation 'com.google.truth:truth:1.0.1'
5757
testImplementation 'junit:junit:4.13'
58+
testImplementation 'org.mockito:mockito-core:3.3.3'
5859
testImplementation "org.robolectric:robolectric:$robolectricVersion"
5960
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,28 @@
1818
import androidx.annotation.VisibleForTesting;
1919
import com.google.android.gms.common.internal.Preconditions;
2020
import com.google.android.gms.tasks.Task;
21+
import com.google.android.gms.tasks.Tasks;
2122
import com.google.firebase.FirebaseApp;
2223
import com.google.firebase.FirebaseOptions;
24+
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2325
import java.util.Set;
2426

2527
public class FirebaseModelDownloader {
2628

2729
private final FirebaseOptions firebaseOptions;
30+
private final SharedPreferencesUtil sharedPreferencesUtil;
2831

29-
FirebaseModelDownloader(FirebaseOptions firebaseOptions) {
30-
this.firebaseOptions = firebaseOptions;
32+
FirebaseModelDownloader(FirebaseApp firebaseApp) {
33+
this.firebaseOptions = firebaseApp.getOptions();
34+
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
3135
}
3236

37+
@VisibleForTesting
38+
FirebaseModelDownloader(
39+
FirebaseOptions firebaseOptions, SharedPreferencesUtil sharedPreferencesUtil) {
40+
this.firebaseOptions = firebaseOptions;
41+
this.sharedPreferencesUtil = sharedPreferencesUtil;
42+
}
3343
/**
3444
* Returns the {@link FirebaseModelDownloader} initialized with the default {@link FirebaseApp}.
3545
*
@@ -84,7 +94,7 @@ public Task<CustomModel> getModel(
8494
/** @return The set of all models that are downloaded to this device. */
8595
@NonNull
8696
public Task<Set<CustomModel>> listDownloadedModels() {
87-
throw new UnsupportedOperationException("Not yet implemented.");
97+
return Tasks.forResult(sharedPreferencesUtil.listDownloadedModels());
8898
}
8999

90100
/*

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import androidx.annotation.NonNull;
1818
import com.google.firebase.FirebaseApp;
19-
import com.google.firebase.FirebaseOptions;
2019
import com.google.firebase.components.Component;
2120
import com.google.firebase.components.ComponentRegistrar;
2221
import com.google.firebase.components.Dependency;
@@ -38,8 +37,7 @@ public List<Component<?>> getComponents() {
3837
return Arrays.asList(
3938
Component.builder(FirebaseModelDownloader.class)
4039
.add(Dependency.required(FirebaseApp.class))
41-
.add(Dependency.required(FirebaseOptions.class))
42-
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseOptions.class)))
40+
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class)))
4341
.build(),
4442
LibraryVersionComponent.create("firebase-ml-modeldownloader", BuildConfig.VERSION_NAME));
4543
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
import androidx.annotation.VisibleForTesting;
2424
import com.google.firebase.FirebaseApp;
2525
import com.google.firebase.ml.modeldownloader.CustomModel;
26+
import java.util.HashSet;
27+
import java.util.Set;
28+
import java.util.regex.Matcher;
29+
import java.util.regex.Pattern;
2630

2731
/** @hide */
2832
public class SharedPreferencesUtil {
@@ -33,6 +37,8 @@ public class SharedPreferencesUtil {
3337
// local model details
3438
private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s";
3539
private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s";
40+
private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?";
41+
3642
private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s";
3743
// details about model during download.
3844
private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s";
@@ -190,6 +196,24 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl
190196
.commit();
191197
}
192198

199+
public synchronized Set<CustomModel> listDownloadedModels() {
200+
Set<CustomModel> customModels = new HashSet<>();
201+
Set<String> keySet = getSharedPreferences().getAll().keySet();
202+
203+
for (String key : keySet) {
204+
// if a local file path is present - get model details.
205+
Matcher matcher = Pattern.compile(LOCAL_MODEL_FILE_PATH_MATCHER).matcher(key);
206+
if (matcher.find()) {
207+
String modelName = matcher.group(matcher.groupCount());
208+
CustomModel extractModel = getCustomModelDetails(modelName);
209+
if (extractModel != null) {
210+
customModels.add(extractModel);
211+
}
212+
}
213+
}
214+
return customModels;
215+
}
216+
193217
/**
194218
* Clears all stored data related to a custom model download.
195219
*

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,20 @@
1414

1515
package com.google.firebase.ml.modeldownloader;
1616

17+
import static com.google.common.truth.Truth.assertThat;
18+
import static org.junit.Assert.assertEquals;
1719
import static org.junit.Assert.assertThrows;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.when;
1822

1923
import androidx.test.core.app.ApplicationProvider;
24+
import com.google.android.gms.tasks.Task;
2025
import com.google.firebase.FirebaseApp;
2126
import com.google.firebase.FirebaseOptions;
27+
import com.google.firebase.FirebaseOptions.Builder;
28+
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
29+
import java.util.Collections;
30+
import java.util.Set;
2231
import org.junit.Before;
2332
import org.junit.Test;
2433
import org.junit.runner.RunWith;
@@ -28,20 +37,28 @@
2837
public class FirebaseModelDownloaderTest {
2938

3039
public static final String TEST_PROJECT_ID = "777777777777";
40+
public static final FirebaseOptions FIREBASE_OPTIONS =
41+
new Builder()
42+
.setApplicationId("1:123456789:android:abcdef")
43+
.setProjectId(TEST_PROJECT_ID)
44+
.build();
3145
public static final String MODEL_NAME = "MODEL_NAME_1";
3246
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
3347
new CustomModelDownloadConditions.Builder().build();
3448

49+
public static final String MODEL_HASH = "dsf324";
50+
// TODO replace with uploaded model.
51+
CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH);
52+
53+
FirebaseModelDownloader firebaseModelDownloader;
54+
SharedPreferencesUtil mockPrefs = mock(SharedPreferencesUtil.class);
55+
3556
@Before
3657
public void setUp() {
3758
FirebaseApp.clearInstancesForTest();
3859
// default app
39-
FirebaseApp.initializeApp(
40-
ApplicationProvider.getApplicationContext(),
41-
new FirebaseOptions.Builder()
42-
.setApplicationId("1:123456789:android:abcdef")
43-
.setProjectId(TEST_PROJECT_ID)
44-
.build());
60+
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);
61+
firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs);
4562
}
4663

4764
@Test
@@ -54,10 +71,19 @@ public void getModel_unimplemented() {
5471
}
5572

5673
@Test
57-
public void listDownloadedModels_unimplemented() {
58-
assertThrows(
59-
UnsupportedOperationException.class,
60-
() -> FirebaseModelDownloader.getInstance().listDownloadedModels());
74+
public void listDownloadedModels_returnsEmptyModelList() {
75+
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet());
76+
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
77+
assertThat(task.isComplete()).isTrue();
78+
assertEquals(task.getResult(), Collections.EMPTY_SET);
79+
}
80+
81+
@Test
82+
public void listDownloadedModels_returnsModelList() {
83+
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL));
84+
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
85+
assertThat(task.isComplete()).isTrue();
86+
assertEquals(task.getResult(), Collections.singleton(CUSTOM_MODEL));
6187
}
6288

6389
@Test

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
import static org.junit.Assert.assertEquals;
1818
import static org.junit.Assert.assertNotNull;
1919
import static org.junit.Assert.assertNull;
20+
import static org.junit.Assert.assertTrue;
2021

2122
import androidx.test.core.app.ApplicationProvider;
2223
import com.google.firebase.FirebaseApp;
2324
import com.google.firebase.FirebaseOptions;
2425
import com.google.firebase.ml.modeldownloader.CustomModel;
26+
import java.util.Set;
2527
import org.junit.Before;
2628
import org.junit.Test;
2729
import org.junit.runner.RunWith;
@@ -118,4 +120,43 @@ public void clearDownloadingModelDetails_keepsLocalModel() throws IllegalArgumen
118120
retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME);
119121
assertEquals(retrievedModel, CUSTOM_MODEL_DOWNLOAD_COMPLETE);
120122
}
123+
124+
@Test
125+
public void listDownloadedModels_localModelFound() throws IllegalArgumentException {
126+
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);
127+
Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
128+
assertEquals(retrievedModel.size(), 1);
129+
assertEquals(retrievedModel.iterator().next(), CUSTOM_MODEL_DOWNLOAD_COMPLETE);
130+
}
131+
132+
@Test
133+
public void listDownloadedModels_downloadingModelNotFound() throws IllegalArgumentException {
134+
sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING);
135+
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
136+
}
137+
138+
@Test
139+
public void listDownloadedModels_noModels() throws IllegalArgumentException {
140+
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
141+
}
142+
143+
@Test
144+
public void listDownloadedModels_multipleModels() throws IllegalArgumentException {
145+
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);
146+
147+
CustomModel model2 =
148+
new CustomModel(MODEL_NAME + "2", 0, 102, MODEL_HASH + "2", "file/path/store/ModelName2/1");
149+
sharedPreferencesUtil.setUploadedCustomModelDetails(model2);
150+
151+
CustomModel model3 =
152+
new CustomModel(MODEL_NAME + "3", 0, 103, MODEL_HASH + "3", "file/path/store/ModelName3/1");
153+
154+
sharedPreferencesUtil.setUploadedCustomModelDetails(model3);
155+
156+
Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
157+
assertEquals(retrievedModel.size(), 3);
158+
assertTrue(retrievedModel.contains(CUSTOM_MODEL_DOWNLOAD_COMPLETE));
159+
assertTrue(retrievedModel.contains(model2));
160+
assertTrue(retrievedModel.contains(model3));
161+
}
121162
}

0 commit comments

Comments
 (0)