Skip to content

Commit 4285659

Browse files
authored
Add Ml shared preferences for custom model storage (#2104)
* Refactor FirebaseMlModel* class names to FirebaseModel*. * Create Storage for CustomModel Details. * Update shared preference calls. * Fix formatting after merge. * Updating based on reviewer comments.
1 parent 844ff28 commit 4285659

File tree

6 files changed

+426
-12
lines changed

6 files changed

+426
-12
lines changed

firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies {
4747
implementation project(':firebase-components')
4848

4949
implementation 'com.google.android.gms:play-services-tasks:17.2.0'
50+
implementation 'javax.inject:javax.inject:1'
5051

5152
compileOnly "com.google.auto.value:auto-value-annotations:1.6.6"
5253
annotationProcessor "com.google.auto.value:auto-value:1.6.5"

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import androidx.annotation.NonNull;
1818
import androidx.annotation.Nullable;
19+
import com.google.android.gms.common.internal.Objects;
1920
import java.io.File;
2021

2122
/**
@@ -28,7 +29,44 @@ public class CustomModel {
2829
private final long downloadId;
2930
private final long fileSize;
3031
private final String modelHash;
31-
private final String localFilePath = "";
32+
private final String localFilePath;
33+
34+
/**
35+
* Use when creating a custom model while the initial download is still in progress.
36+
*
37+
* @param name - model name
38+
* @param downloadId - Android Download Manger - download id
39+
* @param fileSize - model file size
40+
* @param modelHash - model hash size
41+
* @hide
42+
*/
43+
public CustomModel(
44+
@NonNull String name, long downloadId, long fileSize, @NonNull String modelHash) {
45+
this(name, downloadId, fileSize, modelHash, "");
46+
}
47+
48+
/**
49+
* Use when creating a custom model while the initial download is still in progress.
50+
*
51+
* @param name - model name
52+
* @param downloadId - Android Download Manger - download id
53+
* @param fileSize - model file size
54+
* @param modelHash - model hash size
55+
* @param localFilePath - location of the current file
56+
* @hide
57+
*/
58+
public CustomModel(
59+
@NonNull String name,
60+
long downloadId,
61+
long fileSize,
62+
@NonNull String modelHash,
63+
@NonNull String localFilePath) {
64+
this.modelHash = modelHash;
65+
this.name = name;
66+
this.fileSize = fileSize;
67+
this.downloadId = downloadId;
68+
this.localFilePath = localFilePath;
69+
}
3270

3371
@NonNull
3472
public String getName() {
@@ -76,18 +114,36 @@ public long getDownloadId() {
76114
return downloadId;
77115
}
78116

117+
@Override
118+
public boolean equals(Object o) {
119+
if (o == this) {
120+
return true;
121+
}
122+
123+
if (!(o instanceof CustomModel)) {
124+
return false;
125+
}
126+
127+
CustomModel other = (CustomModel) o;
128+
129+
return Objects.equal(name, other.name)
130+
&& Objects.equal(modelHash, other.modelHash)
131+
&& Objects.equal(fileSize, other.fileSize)
132+
&& Objects.equal(localFilePath, other.localFilePath)
133+
&& Objects.equal(downloadId, other.downloadId);
134+
}
135+
136+
@Override
137+
public int hashCode() {
138+
return Objects.hashCode(name, modelHash, fileSize, localFilePath, downloadId);
139+
}
140+
79141
/**
80-
* Use when creating a custom model while the initial download is still in progress.
81-
*
82-
* @param name - model name
83-
* @param downloadId - Android Download Manger - download id
84-
* @param fileSize - model file size
85-
* @param modelHash - model hash size
142+
* @return the model file path
143+
* @hide
86144
*/
87-
CustomModel(String name, long downloadId, long fileSize, String modelHash) {
88-
this.modelHash = modelHash;
89-
this.name = name;
90-
this.fileSize = fileSize;
91-
this.downloadId = downloadId;
145+
@NonNull
146+
public String getLocalFilePath() {
147+
return localFilePath;
92148
}
93149
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package com.google.firebase.ml.modeldownloader;
1616

1717
import androidx.annotation.NonNull;
18+
import com.google.firebase.FirebaseApp;
1819
import com.google.firebase.FirebaseOptions;
1920
import com.google.firebase.components.Component;
2021
import com.google.firebase.components.ComponentRegistrar;
@@ -36,6 +37,7 @@ public class FirebaseModelDownloaderRegistrar implements ComponentRegistrar {
3637
public List<Component<?>> getComponents() {
3738
return Arrays.asList(
3839
Component.builder(FirebaseModelDownloader.class)
40+
.add(Dependency.required(FirebaseApp.class))
3941
.add(Dependency.required(FirebaseOptions.class))
4042
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseOptions.class)))
4143
.build(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
// Copyright 2020 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package com.google.firebase.ml.modeldownloader.internal;
16+
17+
import android.content.Context;
18+
import android.content.SharedPreferences;
19+
import android.content.SharedPreferences.Editor;
20+
import android.os.SystemClock;
21+
import androidx.annotation.NonNull;
22+
import androidx.annotation.Nullable;
23+
import androidx.annotation.VisibleForTesting;
24+
import com.google.firebase.FirebaseApp;
25+
import com.google.firebase.ml.modeldownloader.CustomModel;
26+
27+
/** @hide */
28+
public class SharedPreferencesUtil {
29+
30+
@VisibleForTesting
31+
static final String PREFERENCES_PACKAGE_NAME = "com.google.firebase.ml.modelDownloader";
32+
33+
// local model details
34+
private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s";
35+
private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s";
36+
private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s";
37+
// details about model during download.
38+
private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s";
39+
private static final String DOWNLOADING_MODEL_SIZE_PATTERN = "downloading_model_size_%s_%s";
40+
private static final String DOWNLOADING_MODEL_ID_PATTERN = "downloading_model_id_%s_%s";
41+
private static final String DOWNLOAD_BEGIN_TIME_MS_PATTERN = "downloading_begin_time_%s_%s";
42+
43+
private final String persistenceKey;
44+
private final FirebaseApp firebaseApp;
45+
46+
public SharedPreferencesUtil(FirebaseApp firebaseApp) {
47+
this.firebaseApp = firebaseApp;
48+
this.persistenceKey = firebaseApp.getPersistenceKey();
49+
}
50+
51+
/**
52+
* Returns the Custom Model details currently associated with this model. If a fully downloaded
53+
* model is present - this returns the details of that model, including local file path. If an
54+
* update of an existing model is in progress, the local model plus the download id for the new
55+
* upload is returned. To get only details related to the downloading model use {@link
56+
* #getDownloadingCustomModelDetails}. If this is the initial download of a local file - the
57+
* downloading model details are returned.
58+
*
59+
* @param modelName - name of the model
60+
* @return current version of the Custom Model
61+
*/
62+
@Nullable
63+
public synchronized CustomModel getCustomModelDetails(@NonNull String modelName) {
64+
String modelHash =
65+
getSharedPreferences()
66+
.getString(String.format(LOCAL_MODEL_HASH_PATTERN, persistenceKey, modelName), null);
67+
68+
if (modelHash == null || modelHash.isEmpty()) {
69+
// no model downloaded - check if model is being downloaded.
70+
return getDownloadingCustomModelDetails(modelName);
71+
}
72+
73+
String filePath =
74+
getSharedPreferences()
75+
.getString(String.format(LOCAL_MODEL_FILE_PATH_PATTERN, persistenceKey, modelName), "");
76+
77+
long fileSize =
78+
getSharedPreferences()
79+
.getLong(String.format(LOCAL_MODEL_FILE_SIZE_PATTERN, persistenceKey, modelName), 0);
80+
81+
// if no-zero - local model is present and new model being downloaded
82+
long id =
83+
getSharedPreferences()
84+
.getLong(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName), 0);
85+
86+
return new CustomModel(modelName, id, fileSize, modelHash, filePath);
87+
}
88+
89+
/**
90+
* Returns the Custom Model details associated with this version of this model currently being
91+
* downloaded. If no download is in progress return null. Contains no information about local
92+
* model, only download status.
93+
*
94+
* @param modelName name of the model
95+
* @return Download version of CustomModel
96+
*/
97+
@Nullable
98+
public synchronized CustomModel getDownloadingCustomModelDetails(@NonNull String modelName) {
99+
String modelHash =
100+
getSharedPreferences()
101+
.getString(
102+
String.format(DOWNLOADING_MODEL_HASH_PATTERN, persistenceKey, modelName), null);
103+
104+
if (modelHash == null || modelHash.isEmpty()) {
105+
// no model hash means no download in progress
106+
return null;
107+
}
108+
109+
long fileSize =
110+
getSharedPreferences()
111+
.getLong(String.format(DOWNLOADING_MODEL_SIZE_PATTERN, persistenceKey, modelName), 0);
112+
113+
long id =
114+
getSharedPreferences()
115+
.getLong(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName), 0);
116+
117+
return new CustomModel(modelName, id, fileSize, modelHash);
118+
}
119+
120+
/**
121+
* The information about the new custom model download that need to be stored.
122+
*
123+
* @param customModel custom model details to be stored.
124+
*/
125+
public synchronized void setDownloadingCustomModelDetails(@NonNull CustomModel customModel) {
126+
String modelName = customModel.getName();
127+
String modelHash = customModel.getModelHash();
128+
long downloadId = customModel.getDownloadId();
129+
long modelSize = customModel.getSize();
130+
getSharedPreferences()
131+
.edit()
132+
.putString(
133+
String.format(DOWNLOADING_MODEL_HASH_PATTERN, persistenceKey, modelName), modelHash)
134+
.putLong(
135+
String.format(DOWNLOADING_MODEL_SIZE_PATTERN, persistenceKey, modelName), modelSize)
136+
.putLong(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName), downloadId)
137+
// The following assumes the download will finish before the system reboots.
138+
// If not, the download duration won't be correct, which isn't critical.
139+
.putLong(
140+
String.format(DOWNLOAD_BEGIN_TIME_MS_PATTERN, persistenceKey, modelName),
141+
SystemClock.elapsedRealtime())
142+
.commit();
143+
}
144+
145+
/**
146+
* The information about a completed custom model download. Updates the local model information
147+
* and clears the download details associated with this model.
148+
*
149+
* @param customModel custom model details to be stored.
150+
*/
151+
public synchronized void setUploadedCustomModelDetails(@NonNull CustomModel customModel)
152+
throws IllegalArgumentException {
153+
Long id = customModel.getDownloadId();
154+
// only call when download is completed and download id is reset to 0;
155+
if (!id.equals(0L)) {
156+
throw new IllegalArgumentException("Only call when Custom model has completed download.");
157+
}
158+
Editor editor = getSharedPreferences().edit();
159+
clearDownloadingModelDetails(editor, customModel.getName());
160+
161+
String modelName = customModel.getName();
162+
String hash = customModel.getModelHash();
163+
long size = customModel.getSize();
164+
String filePath = customModel.getLocalFilePath();
165+
editor
166+
.putString(String.format(LOCAL_MODEL_HASH_PATTERN, persistenceKey, modelName), hash)
167+
.putLong(String.format(LOCAL_MODEL_FILE_SIZE_PATTERN, persistenceKey, modelName), size)
168+
.putString(
169+
String.format(LOCAL_MODEL_FILE_PATH_PATTERN, persistenceKey, modelName), filePath)
170+
.commit();
171+
}
172+
173+
/**
174+
* Clears all stored data related to a local custom model, including download details.
175+
*
176+
* @param modelName - name of model
177+
*/
178+
public synchronized void clearModelDetails(@NonNull String modelName, boolean cleanUpModelFile) {
179+
if (cleanUpModelFile) {
180+
// TODO(annz) - add code to remove model files from device
181+
}
182+
Editor editor = getSharedPreferences().edit();
183+
184+
clearDownloadingModelDetails(editor, modelName);
185+
186+
editor
187+
.remove(String.format(LOCAL_MODEL_FILE_PATH_PATTERN, persistenceKey, modelName))
188+
.remove(String.format(LOCAL_MODEL_FILE_SIZE_PATTERN, persistenceKey, modelName))
189+
.remove(String.format(LOCAL_MODEL_HASH_PATTERN, persistenceKey, modelName))
190+
.commit();
191+
}
192+
193+
/**
194+
* Clears all stored data related to a custom model download.
195+
*
196+
* @param modelName - name of model
197+
*/
198+
@VisibleForTesting
199+
synchronized void clearDownloadingModelDetails(Editor editor, @NonNull String modelName) {
200+
editor
201+
.remove(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName))
202+
.remove(String.format(DOWNLOADING_MODEL_HASH_PATTERN, persistenceKey, modelName))
203+
.remove(String.format(DOWNLOADING_MODEL_SIZE_PATTERN, persistenceKey, modelName))
204+
.remove(String.format(DOWNLOAD_BEGIN_TIME_MS_PATTERN, persistenceKey, modelName))
205+
.apply();
206+
}
207+
208+
@VisibleForTesting
209+
SharedPreferences getSharedPreferences() {
210+
return firebaseApp
211+
.getApplicationContext()
212+
.getSharedPreferences(PREFERENCES_PACKAGE_NAME, Context.MODE_PRIVATE);
213+
}
214+
}

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
package com.google.firebase.ml.modeldownloader;
1616

1717
import static org.junit.Assert.assertEquals;
18+
import static org.junit.Assert.assertFalse;
19+
import static org.junit.Assert.assertNotEquals;
1820
import static org.junit.Assert.assertNull;
21+
import static org.junit.Assert.assertTrue;
1922

2023
import org.junit.Test;
2124
import org.junit.runner.RunWith;
@@ -52,4 +55,21 @@ public void customModel_getDownloadId() {
5255
public void customModel_getFile_downloadIncomplete() {
5356
assertNull(CUSTOM_MODEL.getFile());
5457
}
58+
59+
@Test
60+
public void customModel_equals() {
61+
assertTrue(CUSTOM_MODEL.equals(new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH)));
62+
assertFalse(CUSTOM_MODEL.equals(new CustomModel(MODEL_NAME, 0, 101, MODEL_HASH)));
63+
assertFalse(CUSTOM_MODEL.equals(new CustomModel(MODEL_NAME, 101, 100, MODEL_HASH)));
64+
}
65+
66+
@Test
67+
public void customModel_hashCode() {
68+
assertEquals(
69+
CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH).hashCode());
70+
assertNotEquals(
71+
CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, 0, 101, MODEL_HASH).hashCode());
72+
assertNotEquals(
73+
CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, 101, 100, MODEL_HASH).hashCode());
74+
}
5575
}

0 commit comments

Comments
 (0)