17
17
import static com .google .common .truth .Truth .assertThat ;
18
18
import static org .junit .Assert .assertEquals ;
19
19
import static org .junit .Assert .assertThrows ;
20
+ import static org .junit .Assert .assertTrue ;
20
21
import static org .mockito .ArgumentMatchers .any ;
21
22
import static org .mockito .ArgumentMatchers .eq ;
22
23
import static org .mockito .Mockito .doNothing ;
23
24
import static org .mockito .Mockito .times ;
24
25
import static org .mockito .Mockito .verify ;
25
26
import static org .mockito .Mockito .when ;
26
27
28
+ import android .os .ParcelFileDescriptor ;
27
29
import androidx .test .core .app .ApplicationProvider ;
28
30
import com .google .android .gms .tasks .Task ;
29
31
import com .google .android .gms .tasks .Tasks ;
32
34
import com .google .firebase .FirebaseOptions .Builder ;
33
35
import com .google .firebase .ml .modeldownloader .internal .CustomModelDownloadService ;
34
36
import com .google .firebase .ml .modeldownloader .internal .ModelFileDownloadService ;
37
+ import com .google .firebase .ml .modeldownloader .internal .ModelFileManager ;
35
38
import com .google .firebase .ml .modeldownloader .internal .SharedPreferencesUtil ;
39
+ import java .io .File ;
36
40
import java .util .Collections ;
37
41
import java .util .Set ;
38
42
import java .util .concurrent .ExecutorService ;
39
43
import java .util .concurrent .Executors ;
44
+ import org .junit .After ;
40
45
import org .junit .Before ;
41
46
import org .junit .Test ;
42
47
import org .junit .runner .RunWith ;
@@ -54,29 +59,42 @@ public class FirebaseModelDownloaderTest {
54
59
.setProjectId (TEST_PROJECT_ID )
55
60
.build ();
56
61
public static final String MODEL_NAME = "MODEL_NAME_1" ;
62
+ public static final String MODEL_URL = "https://project.firebase.com/modelName/23424.jpg" ;
63
+ private static final long URL_EXPIRATION = 604800L ;
64
+
57
65
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
58
66
new CustomModelDownloadConditions .Builder ().build ();
59
67
60
68
public static final String MODEL_HASH = "dsf324" ;
69
+ public static final String UPDATE_MODEL_HASH = "fgh564" ;
61
70
public static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS =
62
71
new CustomModelDownloadConditions .Builder ().requireWifi ().build ();
63
72
64
73
// TODO replace with uploaded model.
65
74
final CustomModel CUSTOM_MODEL = new CustomModel (MODEL_NAME , MODEL_HASH , 100 , 0 );
75
+ final CustomModel UPDATE_CUSTOM_MODEL_URL =
76
+ new CustomModel (MODEL_NAME , UPDATE_MODEL_HASH , 100 , MODEL_URL , URL_EXPIRATION + 10L );
77
+ CustomModel customModelUploaded ;
66
78
67
79
FirebaseModelDownloader firebaseModelDownloader ;
68
80
@ Mock SharedPreferencesUtil mockPrefs ;
69
81
@ Mock ModelFileDownloadService mockFileDownloadService ;
70
82
@ Mock CustomModelDownloadService mockModelDownloadService ;
71
83
ExecutorService executor ;
72
84
85
+ private File testModelFile ;
86
+ private File updatetestModelFile ;
87
+ private File modelFile0 ;
88
+ String expectedDestinationFolder ;
89
+ ModelFileManager fileManager ;
90
+
73
91
@ Before
74
- public void setUp () {
92
+ public void setUp () throws Exception {
75
93
MockitoAnnotations .initMocks (this );
76
94
FirebaseApp .clearInstancesForTest ();
77
95
// default app
78
- FirebaseApp . initializeApp ( ApplicationProvider . getApplicationContext (), FIREBASE_OPTIONS );
79
-
96
+ FirebaseApp app =
97
+ FirebaseApp . initializeApp ( ApplicationProvider . getApplicationContext (), FIREBASE_OPTIONS );
80
98
executor = Executors .newSingleThreadExecutor ();
81
99
firebaseModelDownloader =
82
100
new FirebaseModelDownloader (
@@ -85,22 +103,164 @@ public void setUp() {
85
103
mockFileDownloadService ,
86
104
mockModelDownloadService ,
87
105
executor );
106
+ setUpTestingFiles (app );
107
+ }
108
+
109
+ private void setUpTestingFiles (FirebaseApp app ) throws Exception {
110
+ fileManager = new ModelFileManager (app );
111
+ final File testDir = new File (app .getApplicationContext ().getNoBackupFilesDir (), "tmpModels" );
112
+ testDir .mkdirs ();
113
+ // make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since
114
+ // we're only using this directory for this test and we won't create any subdirs.
115
+ for (File f : testDir .listFiles ()) {
116
+ if (f .isFile ()) {
117
+ f .delete ();
118
+ }
119
+ }
120
+
121
+ testModelFile = File .createTempFile ("modelFile" , ".tflite" );
122
+ updatetestModelFile = File .createTempFile ("modelFileUpdated" , ".tflite" );
123
+
124
+ expectedDestinationFolder =
125
+ new File (
126
+ app .getApplicationContext ().getNoBackupFilesDir (),
127
+ ModelFileManager .CUSTOM_MODEL_ROOT_PATH )
128
+ .getAbsolutePath ()
129
+ + "/"
130
+ + app .getPersistenceKey ()
131
+ + "/"
132
+ + MODEL_NAME ;
133
+ // move first test file to a model, keep second for "updates"
134
+ ParcelFileDescriptor fd =
135
+ ParcelFileDescriptor .open (testModelFile , ParcelFileDescriptor .MODE_READ_ONLY );
136
+
137
+ modelFile0 = fileManager .moveModelToDestinationFolder (CUSTOM_MODEL , fd );
138
+ assertEquals (modelFile0 , new File (expectedDestinationFolder + "/0" ));
139
+ assertTrue (modelFile0 .exists ());
140
+ customModelUploaded =
141
+ new CustomModel (MODEL_NAME , MODEL_HASH , 100 , 0 , expectedDestinationFolder + "/0" );
88
142
}
89
143
144
+ @ After
145
+ public void teardown () {
146
+ testModelFile .deleteOnExit ();
147
+ updatetestModelFile .deleteOnExit ();
148
+ modelFile0 .deleteOnExit ();
149
+ }
150
+
151
+ // TODO(annz) Add all the conditional unit tests to match!
90
152
@ Test
91
153
public void getModel_unimplemented () {
92
154
assertThrows (
93
155
UnsupportedOperationException .class ,
94
156
() ->
95
157
FirebaseModelDownloader .getInstance ()
96
- .getModel (
97
- MODEL_NAME ,
98
- DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND ,
99
- DEFAULT_DOWNLOAD_CONDITIONS ));
158
+ .getModel (MODEL_NAME , DownloadType .LATEST_MODEL , DEFAULT_DOWNLOAD_CONDITIONS ));
159
+ }
160
+
161
+ @ Test
162
+ public void getModel_updateBackground_localExists_noUpdate () throws Exception {
163
+ when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (CUSTOM_MODEL );
164
+ when (mockModelDownloadService .getCustomModelDetails (
165
+ eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
166
+ .thenReturn (Tasks .forResult (null )); // no change found
167
+
168
+ TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
169
+ Task <CustomModel > task =
170
+ firebaseModelDownloader .getModel (
171
+ MODEL_NAME , DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND , DOWNLOAD_CONDITIONS );
172
+ task .addOnCompleteListener (executor , onCompleteListener );
173
+ CustomModel customModel = onCompleteListener .await ();
174
+
175
+ verify (mockPrefs , times (2 )).getCustomModelDetails (eq (MODEL_NAME ));
176
+ assertThat (task .isComplete ()).isTrue ();
177
+ assertEquals (customModel , CUSTOM_MODEL );
178
+ }
179
+
180
+ @ Test
181
+ public void getModel_updateBackground_localExists_sameHash () throws Exception {
182
+ when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (CUSTOM_MODEL );
183
+ when (mockModelDownloadService .getCustomModelDetails (
184
+ eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
185
+ .thenReturn (Tasks .forResult (CUSTOM_MODEL )); // no change found
186
+
187
+ TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
188
+ Task <CustomModel > task =
189
+ firebaseModelDownloader .getModel (
190
+ MODEL_NAME , DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND , DOWNLOAD_CONDITIONS );
191
+ task .addOnCompleteListener (executor , onCompleteListener );
192
+ CustomModel customModel = onCompleteListener .await ();
193
+
194
+ verify (mockPrefs , times (2 )).getCustomModelDetails (eq (MODEL_NAME ));
195
+ assertThat (task .isComplete ()).isTrue ();
196
+ assertEquals (customModel , CUSTOM_MODEL );
197
+ }
198
+
199
+ @ Test
200
+ public void getModel_updateBackground_localExists_UpdateFound () throws Exception {
201
+ when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (CUSTOM_MODEL );
202
+ when (mockModelDownloadService .getCustomModelDetails (
203
+ eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
204
+ .thenReturn (Tasks .forResult (UPDATE_CUSTOM_MODEL_URL ));
205
+
206
+ TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
207
+ Task <CustomModel > task =
208
+ firebaseModelDownloader .getModel (
209
+ MODEL_NAME , DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND , DOWNLOAD_CONDITIONS );
210
+ task .addOnCompleteListener (executor , onCompleteListener );
211
+ CustomModel customModel = onCompleteListener .await ();
212
+
213
+ verify (mockPrefs , times (1 )).getCustomModelDetails (eq (MODEL_NAME ));
214
+ assertThat (task .isComplete ()).isTrue ();
215
+ assertEquals (customModel , CUSTOM_MODEL );
216
+ }
217
+
218
+ @ Test
219
+ public void getModel_updateBackground_noLocalModel () throws Exception {
220
+ when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (null ).thenReturn (CUSTOM_MODEL );
221
+ when (mockModelDownloadService .getCustomModelDetails (
222
+ eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
223
+ .thenReturn (Tasks .forResult (CUSTOM_MODEL ));
224
+ when (mockFileDownloadService .download (any (), eq (DOWNLOAD_CONDITIONS )))
225
+ .thenReturn (Tasks .forResult (null ));
226
+ TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
227
+ Task <CustomModel > task =
228
+ firebaseModelDownloader .getModel (
229
+ MODEL_NAME , DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND , DOWNLOAD_CONDITIONS );
230
+ task .addOnCompleteListener (executor , onCompleteListener );
231
+ CustomModel customModel = onCompleteListener .await ();
232
+
233
+ verify (mockPrefs , times (2 )).getCustomModelDetails (eq (MODEL_NAME ));
234
+ assertThat (task .isComplete ()).isTrue ();
235
+ assertEquals (customModel , CUSTOM_MODEL );
236
+ }
237
+
238
+ @ Test
239
+ public void getModel_updateBackground_noLocalModel_error () throws Exception {
240
+ when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (null ).thenReturn (CUSTOM_MODEL );
241
+ when (mockModelDownloadService .getCustomModelDetails (
242
+ eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
243
+ .thenReturn (Tasks .forResult (CUSTOM_MODEL ));
244
+ when (mockFileDownloadService .download (any (), eq (DOWNLOAD_CONDITIONS )))
245
+ .thenReturn (Tasks .forException (new Exception ("bad download" )));
246
+ TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
247
+ Task <CustomModel > task =
248
+ firebaseModelDownloader .getModel (
249
+ MODEL_NAME , DownloadType .LOCAL_MODEL_UPDATE_IN_BACKGROUND , DOWNLOAD_CONDITIONS );
250
+ task .addOnCompleteListener (executor , onCompleteListener );
251
+ try {
252
+ onCompleteListener .await ();
253
+ } catch (Exception ex ) {
254
+ assertThat (ex .getMessage ().contains ("download failed" )).isTrue ();
255
+ }
256
+
257
+ verify (mockPrefs , times (1 )).getCustomModelDetails (eq (MODEL_NAME ));
258
+ assertThat (task .isComplete ()).isTrue ();
259
+ assertThat (task .isSuccessful ()).isFalse ();
100
260
}
101
261
102
262
@ Test
103
- public void getModel_localExists () throws Exception {
263
+ public void getModel_Local_localExists () throws Exception {
104
264
when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (CUSTOM_MODEL );
105
265
TestOnCompleteListener <CustomModel > onCompleteListener = new TestOnCompleteListener <>();
106
266
Task <CustomModel > task =
@@ -114,7 +274,7 @@ public void getModel_localExists() throws Exception {
114
274
}
115
275
116
276
@ Test
117
- public void getModel_noLocalModel () throws Exception {
277
+ public void getModel_local_noLocalModel () throws Exception {
118
278
when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (null ).thenReturn (CUSTOM_MODEL );
119
279
when (mockModelDownloadService .getCustomModelDetails (
120
280
eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
@@ -133,7 +293,7 @@ public void getModel_noLocalModel() throws Exception {
133
293
}
134
294
135
295
@ Test
136
- public void getModel_noLocalModel_error () throws Exception {
296
+ public void getModel_local_noLocalModel_error () throws Exception {
137
297
when (mockPrefs .getCustomModelDetails (eq (MODEL_NAME ))).thenReturn (null ).thenReturn (CUSTOM_MODEL );
138
298
when (mockModelDownloadService .getCustomModelDetails (
139
299
eq (TEST_PROJECT_ID ), eq (MODEL_NAME ), eq (null )))
0 commit comments