Skip to content

Commit bc436f0

Browse files
committed
Update to background thread execution for listDownloadModels call.
1 parent deb18fc commit bc436f0

File tree

4 files changed

+140
-10
lines changed

4 files changed

+140
-10
lines changed

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,36 @@
1818
import androidx.annotation.VisibleForTesting;
1919
import com.google.android.gms.common.internal.Preconditions;
2020
import com.google.android.gms.tasks.Task;
21-
import com.google.android.gms.tasks.Tasks;
21+
import com.google.android.gms.tasks.TaskCompletionSource;
2222
import com.google.firebase.FirebaseApp;
2323
import com.google.firebase.FirebaseOptions;
2424
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2525
import java.util.Set;
26+
import java.util.concurrent.Executor;
27+
import java.util.concurrent.Executors;
2628

2729
public class FirebaseModelDownloader {
2830

2931
private final FirebaseOptions firebaseOptions;
3032
private final SharedPreferencesUtil sharedPreferencesUtil;
33+
private final Executor executor;
3134

3235
FirebaseModelDownloader(FirebaseApp firebaseApp) {
3336
this.firebaseOptions = firebaseApp.getOptions();
3437
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
38+
this.executor = Executors.newCachedThreadPool();
3539
}
3640

3741
@VisibleForTesting
3842
FirebaseModelDownloader(
39-
FirebaseOptions firebaseOptions, SharedPreferencesUtil sharedPreferencesUtil) {
43+
FirebaseOptions firebaseOptions,
44+
SharedPreferencesUtil sharedPreferencesUtil,
45+
Executor executor) {
4046
this.firebaseOptions = firebaseOptions;
4147
this.sharedPreferencesUtil = sharedPreferencesUtil;
48+
this.executor = executor;
4249
}
50+
4351
/**
4452
* Returns the {@link FirebaseModelDownloader} initialized with the default {@link FirebaseApp}.
4553
*
@@ -94,7 +102,10 @@ public Task<CustomModel> getModel(
94102
/** @return The set of all models that are downloaded to this device. */
95103
@NonNull
96104
public Task<Set<CustomModel>> listDownloadedModels() {
97-
return Tasks.forResult(sharedPreferencesUtil.listDownloadedModels());
105+
TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
106+
executor.execute(
107+
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
108+
return taskCompletionSource.getTask();
98109
}
99110

100111
/*

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,28 @@ public synchronized Set<CustomModel> listDownloadedModels() {
209209
if (extractModel != null) {
210210
customModels.add(extractModel);
211211
}
212+
} else {
213+
matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key);
214+
if (matcher.find()) {
215+
String modelName = matcher.group(matcher.groupCount());
216+
CustomModel extractModel = isDownloadCompleted(modelName);
217+
if (extractModel != null) {
218+
customModels.add(extractModel);
219+
}
220+
}
212221
}
213222
}
214223
return customModels;
215224
}
216225

226+
synchronized CustomModel isDownloadCompleted(String modelName) {
227+
CustomModel downloadModel = getCustomModelDetails(modelName);
228+
// TODO(annz) check here if download currently in progress have completed.
229+
// if yes, then complete file relocation and return the updated model, otherwise return null
230+
231+
return null;
232+
}
233+
217234
/**
218235
* Clears all stored data related to a custom model download.
219236
*

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import static com.google.common.truth.Truth.assertThat;
1818
import static org.junit.Assert.assertEquals;
1919
import static org.junit.Assert.assertThrows;
20-
import static org.mockito.Mockito.mock;
2120
import static org.mockito.Mockito.when;
2221

2322
import androidx.test.core.app.ApplicationProvider;
@@ -28,9 +27,17 @@
2827
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2928
import java.util.Collections;
3029
import java.util.Set;
30+
import java.util.concurrent.ExecutionException;
31+
import java.util.concurrent.ExecutorService;
32+
import java.util.concurrent.LinkedBlockingQueue;
33+
import java.util.concurrent.ThreadPoolExecutor;
34+
import java.util.concurrent.TimeUnit;
35+
import org.junit.After;
3136
import org.junit.Before;
3237
import org.junit.Test;
3338
import org.junit.runner.RunWith;
39+
import org.mockito.Mock;
40+
import org.mockito.MockitoAnnotations;
3441
import org.robolectric.RobolectricTestRunner;
3542

3643
@RunWith(RobolectricTestRunner.class)
@@ -51,14 +58,28 @@ public class FirebaseModelDownloaderTest {
5158
CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH);
5259

5360
FirebaseModelDownloader firebaseModelDownloader;
54-
SharedPreferencesUtil mockPrefs = mock(SharedPreferencesUtil.class);
61+
@Mock SharedPreferencesUtil mockPrefs;
62+
63+
ExecutorService executor;
5564

5665
@Before
5766
public void setUp() {
67+
MockitoAnnotations.initMocks(this);
5868
FirebaseApp.clearInstancesForTest();
5969
// default app
6070
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);
61-
firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs);
71+
72+
executor = new ThreadPoolExecutor(0, 1, 30L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
73+
firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor);
74+
}
75+
76+
@After
77+
public void cleanUp() {
78+
try {
79+
executor.awaitTermination(250, TimeUnit.MILLISECONDS);
80+
} catch (InterruptedException e) {
81+
// do nothing.
82+
}
6283
}
6384

6485
@Test
@@ -71,19 +92,32 @@ public void getModel_unimplemented() {
7192
}
7293

7394
@Test
74-
public void listDownloadedModels_returnsEmptyModelList() {
95+
public void listDownloadedModels_returnsEmptyModelList()
96+
throws ExecutionException, InterruptedException {
7597
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet());
98+
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
7699
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
100+
task.addOnCompleteListener(executor, onCompleteListener);
101+
Set<CustomModel> customModelSet = onCompleteListener.await();
102+
77103
assertThat(task.isComplete()).isTrue();
78-
assertEquals(task.getResult(), Collections.EMPTY_SET);
104+
assertEquals(customModelSet, Collections.EMPTY_SET);
79105
}
80106

81107
@Test
82-
public void listDownloadedModels_returnsModelList() {
108+
public void listDownloadedModels_returnsModelList()
109+
throws ExecutionException, InterruptedException {
83110
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL));
111+
112+
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
84113
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
114+
task.addOnCompleteListener(executor, onCompleteListener);
115+
Set<CustomModel> customModelSet = onCompleteListener.await();
116+
85117
assertThat(task.isComplete()).isTrue();
86-
assertEquals(task.getResult(), Collections.singleton(CUSTOM_MODEL));
118+
assertEquals(customModelSet, Collections.singleton(CUSTOM_MODEL));
119+
120+
executor.awaitTermination(500, TimeUnit.MILLISECONDS);
87121
}
88122

89123
@Test
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2020 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package com.google.firebase.ml.modeldownloader;
16+
17+
import androidx.annotation.NonNull;
18+
import com.google.android.gms.tasks.OnCompleteListener;
19+
import com.google.android.gms.tasks.Task;
20+
import java.io.IOException;
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.ExecutionException;
23+
import java.util.concurrent.TimeUnit;
24+
25+
/**
26+
* Helper listener that works around a limitation of the Tasks API where await() cannot be called on
27+
* the main thread. This listener works around it by running itself on a different thread, thus
28+
* allowing the main thread to be woken up when the Tasks complete.
29+
*/
30+
public class TestOnCompleteListener<TResult> implements OnCompleteListener<TResult> {
31+
private static final long TIMEOUT_MS = 5000;
32+
private final CountDownLatch latch = new CountDownLatch(1);
33+
private Task<TResult> task;
34+
private volatile TResult result;
35+
private volatile Exception exception;
36+
private volatile boolean successful;
37+
38+
@Override
39+
public void onComplete(@NonNull Task<TResult> task) {
40+
this.task = task;
41+
successful = task.isSuccessful();
42+
if (successful) {
43+
result = task.getResult();
44+
} else {
45+
exception = task.getException();
46+
}
47+
latch.countDown();
48+
}
49+
50+
/** Blocks until the {@link #onComplete} is called. */
51+
public TResult await() throws InterruptedException, ExecutionException {
52+
if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
53+
throw new InterruptedException("timed out waiting for result");
54+
}
55+
if (successful) {
56+
return result;
57+
} else {
58+
if (exception instanceof InterruptedException) {
59+
throw (InterruptedException) exception;
60+
}
61+
// todo(annz) add firebase ml exception handling here.
62+
if (exception instanceof IOException) {
63+
throw new ExecutionException(exception);
64+
}
65+
throw new IllegalStateException("got an unexpected exception type", exception);
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)