-
Notifications
You must be signed in to change notification settings - Fork 625
Add implementation for listDownloadedModels. #2154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
deb18fc
bc436f0
c83d392
f21799f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,10 @@ | |
import androidx.annotation.VisibleForTesting; | ||
import com.google.firebase.FirebaseApp; | ||
import com.google.firebase.ml.modeldownloader.CustomModel; | ||
import java.util.HashSet; | ||
import java.util.Set; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
/** @hide */ | ||
public class SharedPreferencesUtil { | ||
|
@@ -33,6 +37,8 @@ public class SharedPreferencesUtil { | |
// local model details | ||
private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s"; | ||
private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s"; | ||
private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?"; | ||
|
||
private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s"; | ||
// details about model during download. | ||
private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s"; | ||
|
@@ -190,6 +196,41 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl | |
.commit(); | ||
} | ||
|
||
public synchronized Set<CustomModel> listDownloadedModels() { | ||
Set<CustomModel> customModels = new HashSet<>(); | ||
Set<String> keySet = getSharedPreferences().getAll().keySet(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you're reading all the keys at once and not dealing with the sharedpreferences anymore, it may not be necessary to sync the whole method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The second part (added todo) will need to coordinate with android download manager, so I'll need the sync when I add that. |
||
|
||
for (String key : keySet) { | ||
// if a local file path is present - get model details. | ||
Matcher matcher = Pattern.compile(LOCAL_MODEL_FILE_PATH_MATCHER).matcher(key); | ||
if (matcher.find()) { | ||
String modelName = matcher.group(matcher.groupCount()); | ||
CustomModel extractModel = getCustomModelDetails(modelName); | ||
if (extractModel != null) { | ||
customModels.add(extractModel); | ||
} | ||
} else { | ||
matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key); | ||
if (matcher.find()) { | ||
String modelName = matcher.group(matcher.groupCount()); | ||
CustomModel extractModel = maybeGetUpdatedModel(modelName); | ||
if (extractModel != null) { | ||
customModels.add(extractModel); | ||
} | ||
} | ||
} | ||
} | ||
return customModels; | ||
} | ||
|
||
synchronized CustomModel maybeGetUpdatedModel(String modelName) { | ||
CustomModel downloadModel = getCustomModelDetails(modelName); | ||
// TODO(annz) check here if download currently in progress have completed. | ||
// if yes, then complete file relocation and return the updated model, otherwise return null | ||
|
||
return null; | ||
} | ||
|
||
/** | ||
* Clears all stored data related to a custom model download. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,34 +14,60 @@ | |
|
||
package com.google.firebase.ml.modeldownloader; | ||
|
||
import static com.google.common.truth.Truth.assertThat; | ||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertThrows; | ||
import static org.mockito.Mockito.when; | ||
|
||
import androidx.test.core.app.ApplicationProvider; | ||
import com.google.android.gms.tasks.Task; | ||
import com.google.firebase.FirebaseApp; | ||
import com.google.firebase.FirebaseOptions; | ||
import com.google.firebase.FirebaseOptions.Builder; | ||
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; | ||
import java.util.Collections; | ||
import java.util.Set; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.ExecutorService; | ||
import java.util.concurrent.Executors; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
import org.robolectric.RobolectricTestRunner; | ||
|
||
@RunWith(RobolectricTestRunner.class) | ||
public class FirebaseModelDownloaderTest { | ||
|
||
public static final String TEST_PROJECT_ID = "777777777777"; | ||
public static final FirebaseOptions FIREBASE_OPTIONS = | ||
new Builder() | ||
.setApplicationId("1:123456789:android:abcdef") | ||
.setProjectId(TEST_PROJECT_ID) | ||
.build(); | ||
public static final String MODEL_NAME = "MODEL_NAME_1"; | ||
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS = | ||
new CustomModelDownloadConditions.Builder().build(); | ||
|
||
public static final String MODEL_HASH = "dsf324"; | ||
// TODO replace with uploaded model. | ||
CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH); | ||
|
||
FirebaseModelDownloader firebaseModelDownloader; | ||
@Mock SharedPreferencesUtil mockPrefs; | ||
|
||
ExecutorService executor; | ||
|
||
@Before | ||
public void setUp() { | ||
MockitoAnnotations.initMocks(this); | ||
FirebaseApp.clearInstancesForTest(); | ||
// default app | ||
FirebaseApp.initializeApp( | ||
ApplicationProvider.getApplicationContext(), | ||
new FirebaseOptions.Builder() | ||
.setApplicationId("1:123456789:android:abcdef") | ||
.setProjectId(TEST_PROJECT_ID) | ||
.build()); | ||
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); | ||
|
||
executor = Executors.newSingleThreadExecutor(); | ||
firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor); | ||
} | ||
|
||
@Test | ||
|
@@ -54,10 +80,30 @@ public void getModel_unimplemented() { | |
} | ||
|
||
@Test | ||
public void listDownloadedModels_unimplemented() { | ||
assertThrows( | ||
UnsupportedOperationException.class, | ||
() -> FirebaseModelDownloader.getInstance().listDownloadedModels()); | ||
public void listDownloadedModels_returnsEmptyModelList() | ||
throws ExecutionException, InterruptedException { | ||
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet()); | ||
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>(); | ||
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels(); | ||
task.addOnCompleteListener(executor, onCompleteListener); | ||
Set<CustomModel> customModelSet = onCompleteListener.await(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had tried that got - "java.lang.IllegalStateException: Must not be called on the main application thread", found this solution was used elsewhere. |
||
|
||
assertThat(task.isComplete()).isTrue(); | ||
assertEquals(customModelSet, Collections.EMPTY_SET); | ||
} | ||
|
||
@Test | ||
public void listDownloadedModels_returnsModelList() | ||
throws ExecutionException, InterruptedException { | ||
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL)); | ||
|
||
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>(); | ||
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels(); | ||
task.addOnCompleteListener(executor, onCompleteListener); | ||
Set<CustomModel> customModelSet = onCompleteListener.await(); | ||
|
||
assertThat(task.isComplete()).isTrue(); | ||
assertEquals(customModelSet, Collections.singleton(CUSTOM_MODEL)); | ||
} | ||
|
||
@Test | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright 2020 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package com.google.firebase.ml.modeldownloader; | ||
|
||
import androidx.annotation.NonNull; | ||
import com.google.android.gms.tasks.OnCompleteListener; | ||
import com.google.android.gms.tasks.Task; | ||
import java.io.IOException; | ||
import java.util.concurrent.CountDownLatch; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
/** | ||
* Helper listener that works around a limitation of the Tasks API where await() cannot be called on | ||
* the main thread. This listener works around it by running itself on a different thread, thus | ||
* allowing the main thread to be woken up when the Tasks complete. | ||
*/ | ||
public class TestOnCompleteListener<TResult> implements OnCompleteListener<TResult> { | ||
private static final long TIMEOUT_MS = 5000; | ||
private final CountDownLatch latch = new CountDownLatch(1); | ||
private Task<TResult> task; | ||
private volatile TResult result; | ||
private volatile Exception exception; | ||
private volatile boolean successful; | ||
|
||
@Override | ||
public void onComplete(@NonNull Task<TResult> task) { | ||
this.task = task; | ||
successful = task.isSuccessful(); | ||
if (successful) { | ||
result = task.getResult(); | ||
} else { | ||
exception = task.getException(); | ||
} | ||
latch.countDown(); | ||
} | ||
|
||
/** Blocks until the {@link #onComplete} is called. */ | ||
public TResult await() throws InterruptedException, ExecutionException { | ||
if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) { | ||
throw new InterruptedException("timed out waiting for result"); | ||
} | ||
if (successful) { | ||
return result; | ||
} else { | ||
if (exception instanceof InterruptedException) { | ||
throw (InterruptedException) exception; | ||
} | ||
// todo(annz) add firebase ml exception handling here. | ||
if (exception instanceof IOException) { | ||
throw new ExecutionException(exception); | ||
} | ||
throw new IllegalStateException("got an unexpected exception type", exception); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.