|
29 | 29 | import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
|
30 | 30 | import com.google.firebase.ml.modeldownloader.internal.ModelFileManager;
|
31 | 31 | import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
|
| 32 | +import java.io.File; |
32 | 33 | import java.util.Set;
|
33 | 34 | import java.util.concurrent.Executor;
|
34 | 35 | import java.util.concurrent.Executors;
|
@@ -120,47 +121,89 @@ public Task<CustomModel> getModel(
|
120 | 121 | @Nullable CustomModelDownloadConditions conditions)
|
121 | 122 | throws Exception {
|
122 | 123 | CustomModel localModel = sharedPreferencesUtil.getCustomModelDetails(modelName);
|
| 124 | + if (localModel == null) { |
| 125 | + // no local model - get latest. |
| 126 | + return getCustomModelTask(modelName, conditions); |
| 127 | + } |
| 128 | + |
123 | 129 | switch (downloadType) {
|
124 | 130 | case LOCAL_MODEL:
|
125 |
| - if (localModel != null) { |
126 |
| - return Tasks.forResult(localModel); |
127 |
| - } |
128 |
| - Task<CustomModel> modelDetails = |
129 |
| - modelDownloadService.getCustomModelDetails( |
130 |
| - firebaseOptions.getProjectId(), modelName, null); |
131 |
| - |
132 |
| - // no local model - start download. |
133 |
| - return modelDetails.continueWithTask( |
134 |
| - executor, |
135 |
| - modelDetailTask -> { |
136 |
| - if (modelDetailTask.isSuccessful()) { |
137 |
| - // start download |
138 |
| - return fileDownloadService |
139 |
| - .download(modelDetailTask.getResult(), conditions) |
140 |
| - .continueWithTask( |
141 |
| - executor, |
142 |
| - downloadTask -> { |
143 |
| - if (downloadTask.isSuccessful()) { |
144 |
| - // read the updated model |
145 |
| - CustomModel downloadedModel = |
146 |
| - sharedPreferencesUtil.getCustomModelDetails(modelName); |
147 |
| - // TODO(annz) trigger file move here as well... right now it's temp |
148 |
| - // call loadNewlyDownloadedModelFile |
149 |
| - return Tasks.forResult(downloadedModel); |
150 |
| - } |
151 |
| - return Tasks.forException(new Exception("File download failed.")); |
152 |
| - }); |
153 |
| - } |
154 |
| - return Tasks.forException(modelDetailTask.getException()); |
155 |
| - }); |
| 131 | + return Tasks.forResult(localModel); |
156 | 132 | case LATEST_MODEL:
|
157 |
| - // check for latest model and download newest |
158 |
| - break; |
| 133 | + // check for latest model, wait for download if newer model exists |
| 134 | + return getCustomModelTask(modelName, conditions, localModel.getModelHash()); |
159 | 135 | case LOCAL_MODEL_UPDATE_IN_BACKGROUND:
|
160 |
| - // start download in back ground return current model if not null. |
161 |
| - break; |
| 136 | + // start download in back ground, return local model |
| 137 | + getCustomModelTask(modelName, conditions, localModel.getModelHash()); |
| 138 | + return Tasks.forResult(localModel); |
162 | 139 | }
|
163 |
| - throw new UnsupportedOperationException("Not yet implemented."); |
| 140 | + throw new IllegalArgumentException( |
| 141 | + "Unsupported downloadType, please chose LOCAL_MODEL, LATEST_MODEL, or LOCAL_MODEL_UPDATE_IN_BACKGROUND"); |
| 142 | + } |
| 143 | + |
| 144 | + // This version of getCustomModelTask will always call the modelDownloadService and upon |
| 145 | + // success will then trigger file download. |
| 146 | + private Task<CustomModel> getCustomModelTask( |
| 147 | + @NonNull String modelName, @Nullable CustomModelDownloadConditions conditions) |
| 148 | + throws Exception { |
| 149 | + return getCustomModelTask(modelName, conditions, null); |
| 150 | + } |
| 151 | + |
| 152 | + // This version of getCustomModelTask will call the modelDownloadService and upon |
| 153 | + // success will only trigger file download, if there is a new model hash value. |
| 154 | + private Task<CustomModel> getCustomModelTask( |
| 155 | + @NonNull String modelName, |
| 156 | + @Nullable CustomModelDownloadConditions conditions, |
| 157 | + @Nullable String modelHash) |
| 158 | + throws Exception { |
| 159 | + Task<CustomModel> incomingModelDetails = |
| 160 | + modelDownloadService.getCustomModelDetails( |
| 161 | + firebaseOptions.getProjectId(), modelName, modelHash); |
| 162 | + |
| 163 | + return incomingModelDetails.continueWithTask( |
| 164 | + executor, |
| 165 | + incomingModelDetailTask -> { |
| 166 | + if (incomingModelDetailTask.isSuccessful()) { |
| 167 | + CustomModel currentModel = sharedPreferencesUtil.getCustomModelDetails(modelName); |
| 168 | + // null means we have the latest model |
| 169 | + if (incomingModelDetailTask.getResult() == null) { |
| 170 | + return Tasks.forResult(currentModel); |
| 171 | + } |
| 172 | + |
| 173 | + // if modelHash matches current local model just return local model. |
| 174 | + // Should be handled by above case but just in case. |
| 175 | + if (currentModel != null |
| 176 | + && currentModel |
| 177 | + .getModelHash() |
| 178 | + .equals(incomingModelDetails.getResult().getModelHash())) { |
| 179 | + if (!currentModel.getLocalFilePath().isEmpty() |
| 180 | + && new File(currentModel.getLocalFilePath()).exists()) { |
| 181 | + return Tasks.forResult(currentModel); |
| 182 | + } |
| 183 | + // todo(annzimmer) this shouldn't happen unless they are calling the sdk with multiple |
| 184 | + // sets of download types/conditions. |
| 185 | + // this should be a download in progress - add appropriate handling. |
| 186 | + } |
| 187 | + |
| 188 | + // start download |
| 189 | + return fileDownloadService |
| 190 | + .download(incomingModelDetailTask.getResult(), conditions) |
| 191 | + .continueWithTask( |
| 192 | + executor, |
| 193 | + downloadTask -> { |
| 194 | + if (downloadTask.isSuccessful()) { |
| 195 | + // read the updated model |
| 196 | + CustomModel downloadedModel = |
| 197 | + sharedPreferencesUtil.getCustomModelDetails(modelName); |
| 198 | + // trigger the file to be moved to permanent location. |
| 199 | + fileDownloadService.loadNewlyDownloadedModelFile(downloadedModel); |
| 200 | + return Tasks.forResult(downloadedModel); |
| 201 | + } |
| 202 | + return Tasks.forException(new Exception("File download failed.")); |
| 203 | + }); |
| 204 | + } |
| 205 | + return Tasks.forException(incomingModelDetailTask.getException()); |
| 206 | + }); |
164 | 207 | }
|
165 | 208 |
|
166 | 209 | /**
|
|
0 commit comments