Skip to content

Add model file clean-up on initial load for all models found in share… #2353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public class ModelFileDownloadService {
private final SharedPreferencesUtil sharedPreferencesUtil;
private final FirebaseMlLogger eventLogger;

private boolean isInitialLoad;

@GuardedBy("this")
// Mapping from download id to broadcast receiver. Because models can update, we cannot just keep
// one instance of DownloadBroadcastReceiver per RemoteModelDownloadManager object.
Expand All @@ -87,6 +89,7 @@ public ModelFileDownloadService(
downloadManager = (DownloadManager) context.getSystemService(Context.DOWNLOAD_SERVICE);
this.fileManager = ModelFileManager.getInstance();
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
this.isInitialLoad = true;
this.eventLogger = FirebaseMlLogger.getInstance();
}

Expand All @@ -96,12 +99,14 @@ public ModelFileDownloadService(
DownloadManager downloadManager,
ModelFileManager fileManager,
SharedPreferencesUtil sharedPreferencesUtil,
FirebaseMlLogger eventLogger) {
FirebaseMlLogger eventLogger,
boolean isInitialLoad) {
this.context = firebaseApp.getApplicationContext();
this.downloadManager = downloadManager;
this.fileManager = fileManager;
this.sharedPreferencesUtil = sharedPreferencesUtil;
this.eventLogger = eventLogger;
this.isInitialLoad = isInitialLoad;
}

/**
Expand Down Expand Up @@ -404,7 +409,8 @@ public File loadNewlyDownloadedModelFile(CustomModel model) {
new CustomModel(
model.getName(), model.getModelHash(), model.getSize(), 0, newModelFile.getPath()));

// todo(annzimmer) Cleans up the old files if it is the initial creation.
maybeCleanUpOldModels();

return newModelFile;
} else if (statusCode == DownloadManager.STATUS_FAILED) {
Log.d(TAG, "Model downloaded failed.");
Expand All @@ -418,6 +424,24 @@ public File loadNewlyDownloadedModelFile(CustomModel model) {
return null;
}

private Task<Void> maybeCleanUpOldModels() {
if (!isInitialLoad) {
return Tasks.forResult(null);
}

// only do once per initialization.
isInitialLoad = false;

// for each custom model directory, find out the latest model and delete the other files.
// If no corresponding model, clean up the full directory.
try {
fileManager.deleteNonLatestCustomModels();
} catch (FirebaseMlException fex) {
Log.d(TAG, "Failed to clean up old models.");
}
return Tasks.forResult(null);
}

private FirebaseMlException getExceptionAccordingToDownloadManager(Long downloadId) {
int errorCode = FirebaseMlException.INTERNAL;
String errorMessage = "Model downloading failed";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import androidx.annotation.Nullable;
import androidx.annotation.VisibleForTesting;
import androidx.annotation.WorkerThread;
import com.google.android.gms.common.internal.Preconditions;
import com.google.firebase.FirebaseApp;
import com.google.firebase.ml.modeldownloader.CustomModel;
import com.google.firebase.ml.modeldownloader.FirebaseMlException;
Expand All @@ -44,10 +43,12 @@ public class ModelFileManager {
private static final int INVALID_INDEX = -1;
private final Context context;
private final FirebaseApp firebaseApp;
private final SharedPreferencesUtil sharedPreferencesUtil;

public ModelFileManager(@NonNull FirebaseApp firebaseApp) {
this.context = firebaseApp.getApplicationContext();
this.firebaseApp = firebaseApp;
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
}

/**
Expand All @@ -61,6 +62,23 @@ public static ModelFileManager getInstance() {
return FirebaseApp.getInstance().get(ModelFileManager.class);
}

void deleteNonLatestCustomModels() throws FirebaseMlException {
File root = getDirImpl("");

boolean ret = true;
if (root.isDirectory()) {
for (File f : root.listFiles()) {
// for each custom model sub directory - extract customModelName and clean up old models.
String modelName = f.getName();

CustomModel model = sharedPreferencesUtil.getCustomModelDetails(modelName);
if (model != null) {
deleteOldModels(modelName, model.getLocalFilePath());
}
}
}
}

/**
* Get the directory where the model is supposed to reside. This method does not ensure that the
* directory specified does exist. If you need to ensure its existence, you should call
Expand Down Expand Up @@ -178,6 +196,40 @@ public synchronized File moveModelToDestinationFolder(
return modelFileDestination;
}

/**
* Deletes old models in the custom model directory, except the {@code latestModelFilePath}. This
* should only be called when no files are in use or more specifically when the first
* initialization, otherwise it may remove a model that is in use.
*
* @param latestModelFilePath The file path to the latest custom model.
*/
@WorkerThread
public synchronized void deleteOldModels(
@NonNull String modelName, @NonNull String latestModelFilePath) {
File modelFolder = getModelDirUnsafe(modelName);
if (!modelFolder.exists()) {
return;
}

File latestFile = new File(latestModelFilePath);
int latestIndex = Integer.parseInt(latestFile.getName());
File[] modelFiles = modelFolder.listFiles();

boolean isAllDeleted = true;
int fileInt;
for (File modelFile : modelFiles) {
try {
fileInt = Integer.parseInt(modelFile.getName());
} catch (NumberFormatException ex) {
// unexpected file - ignore
fileInt = Integer.MAX_VALUE;
}
if (fileInt < latestIndex) {
isAllDeleted = isAllDeleted && modelFile.delete();
}
}
}

/**
* Deletes all previously cached Model File(s) and the model root folder.
*
Expand All @@ -203,7 +255,7 @@ boolean deleteRecursively(@Nullable File root) {

boolean ret = true;
if (root.isDirectory()) {
for (File f : Preconditions.checkNotNull(root.listFiles())) {
for (File f : root.listFiles()) {
ret = ret && deleteRecursively(f);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ public class ModelFileDownloadServiceTest {
File testAppModelFile;

private ModelFileDownloadService modelFileDownloadService;

private ModelFileDownloadService modelFileDownloadServiceInitialLoad;
private SharedPreferencesUtil sharedPreferencesUtil;
@Mock DownloadManager mockDownloadManager;
@Mock ModelFileManager mockFileManager;
Expand All @@ -116,7 +118,21 @@ public void setUp() throws IOException {

modelFileDownloadService =
new ModelFileDownloadService(
app, mockDownloadManager, mockFileManager, sharedPreferencesUtil, mockStatsLogger);
app,
mockDownloadManager,
mockFileManager,
sharedPreferencesUtil,
mockStatsLogger,
false);

modelFileDownloadServiceInitialLoad =
new ModelFileDownloadService(
app,
mockDownloadManager,
mockFileManager,
sharedPreferencesUtil,
mockStatsLogger,
true);

matrixCursor = new MatrixCursor(new String[] {DownloadManager.COLUMN_STATUS});
testTempModelFile = File.createTempFile("fakeTempFile", ".tflite");
Expand Down Expand Up @@ -519,7 +535,6 @@ public void ensureModelDownloaded_alreadyInProgess_UrlExpired() throws Exception
when(mockDownloadManager.enqueue(any())).thenReturn(DOWNLOAD_ID).thenReturn(downloadId2);

// first download will get cancelled and cleaned - up before intent is sent.

// Complete the second download
Intent downloadCompleteIntent = new Intent(DownloadManager.ACTION_DOWNLOAD_COMPLETE);
downloadCompleteIntent.putExtra(DownloadManager.EXTRA_DOWNLOAD_ID, downloadId2);
Expand Down Expand Up @@ -761,6 +776,7 @@ public void loadNewlyDownloadedModelFile_successFilePresent()
CustomModel retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME);
assertEquals(retrievedModel, customModelDownloadComplete);
verify(mockDownloadManager, times(1)).remove(anyLong());
verify(mockFileManager, never()).deleteNonLatestCustomModels();
verify(mockStatsLogger, times(1))
.logDownloadEventWithErrorCode(
eq(CUSTOM_MODEL_DOWNLOADING),
Expand All @@ -770,7 +786,8 @@ public void loadNewlyDownloadedModelFile_successFilePresent()
}

@Test
public void loadNewlyDownloadedModelFile_successNoFile() throws FileNotFoundException {
public void loadNewlyDownloadedModelFile_successNoFile()
throws FileNotFoundException, FirebaseMlException {
// Not found
assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L));
matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL});
Expand All @@ -783,6 +800,7 @@ public void loadNewlyDownloadedModelFile_successNoFile() throws FileNotFoundExce
assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING));
assertNull(sharedPreferencesUtil.getDownloadingCustomModelDetails(MODEL_NAME));
verify(mockDownloadManager, times(1)).remove(anyLong());
verify(mockFileManager, never()).deleteNonLatestCustomModels();
verify(mockStatsLogger, times(1))
.logDownloadEventWithErrorCode(
eq(CUSTOM_MODEL_DOWNLOADING),
Expand All @@ -792,20 +810,21 @@ public void loadNewlyDownloadedModelFile_successNoFile() throws FileNotFoundExce
}

@Test
public void loadNewlyDownloadedModelFile_Running() {
public void loadNewlyDownloadedModelFile_Running() throws FirebaseMlException {
// Not found
assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L));
matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_RUNNING});
when(mockDownloadManager.query(any())).thenReturn(matrixCursor);
assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING));
assertNull(sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME));
verify(mockDownloadManager, never()).remove(anyLong());
verify(mockFileManager, never()).deleteNonLatestCustomModels();
verify(mockStatsLogger, never())
.logDownloadEventWithErrorCode(any(), anyBoolean(), any(), any());
}

@Test
public void loadNewlyDownloadedModelFile_Failed() {
public void loadNewlyDownloadedModelFile_Failed() throws FirebaseMlException {
// Not found
assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L));
matrixCursor =
Expand All @@ -820,10 +839,44 @@ public void loadNewlyDownloadedModelFile_Failed() {
assertNull(modelFileDownloadService.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING));
assertNull(sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME));
verify(mockDownloadManager, times(1)).remove(anyLong());
verify(mockStatsLogger, never())
.logDownloadEventWithErrorCode(any(), anyBoolean(), any(), any());
verify(mockStatsLogger, times(1))
.logDownloadFailureWithReason(
eq(CUSTOM_MODEL_DOWNLOADING), eq(false), eq(DownloadManager.ERROR_INSUFFICIENT_SPACE));
verify(mockFileManager, never()).deleteNonLatestCustomModels();
}

@Test
public void loadNewlyDownloadedModelFile_initialLoad_successFilePresent()
throws FirebaseMlException, FileNotFoundException {
// Not found
assertNull(modelFileDownloadServiceInitialLoad.getDownloadingModelStatusCode(0L));
matrixCursor.addRow(new Integer[] {DownloadManager.STATUS_SUCCESSFUL});
when(mockDownloadManager.query(any())).thenReturn(matrixCursor);
when(mockDownloadManager.openDownloadedFile(anyLong()))
.thenReturn(
ParcelFileDescriptor.open(testTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY));
when(mockDownloadManager.remove(anyLong())).thenReturn(1);

when(mockFileManager.moveModelToDestinationFolder(any(), any())).thenReturn(testAppModelFile);

assertEquals(
modelFileDownloadServiceInitialLoad.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING),
testAppModelFile);

// second attempt should not call deleteNonLatestCustomModels a second time.
assertEquals(
modelFileDownloadServiceInitialLoad.loadNewlyDownloadedModelFile(CUSTOM_MODEL_DOWNLOADING),
testAppModelFile);

CustomModel retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME);
assertEquals(retrievedModel, customModelDownloadComplete);
verify(mockDownloadManager, times(2)).remove(anyLong());
verify(mockFileManager, times(1)).deleteNonLatestCustomModels();
verify(mockStatsLogger, times(2))
.logDownloadEventWithErrorCode(
eq(CUSTOM_MODEL_DOWNLOADING),
eq(true),
eq(DownloadStatus.SUCCEEDED),
eq(ErrorCode.NO_ERROR));
}
}
Loading