
Commit a4d67e2

Android: Leverage prefillPrompt and prefillImage on Llava
Differential Revision: D62411342
Pull Request resolved: #5224
1 parent d38ca81 commit a4d67e2

File tree

2 files changed, +53 -46 lines

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 48 additions & 46 deletions
@@ -19,6 +19,7 @@
 import android.os.Bundle;
 import android.os.Handler;
 import android.os.Looper;
+import android.os.Process;
 import android.provider.MediaStore;
 import android.system.ErrnoException;
 import android.system.Os;
@@ -44,6 +45,8 @@
 import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
 import org.pytorch.executorch.LlamaCallback;
 import org.pytorch.executorch.LlamaModule;
 
@@ -71,15 +74,16 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
   private Handler mMemoryUpdateHandler;
   private Runnable memoryUpdater;
   private int promptID = 0;
-
+  private long startPos = 0;
   private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
+  private Executor executor;
 
   @Override
   public void onResult(String result) {
     if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) {
       return;
     }
-    if (result.equals("\n\n")) {
+    if (result.equals("\n\n") || result.equals("\n")) {
       if (!mResultMessage.getText().isEmpty()) {
         mResultMessage.appendText(result);
         run();
@@ -150,6 +154,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
               + (float) loadDuration / 1000
               + " sec."
               + " You can send text or image for inference";
+
+      if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
+        ETLogging.getInstance().log("Llava start prefill prompt");
+        startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
+        ETLogging.getInstance().log("Llava completes prefill prompt");
+      }
     }
 
     Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0);
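In this hunk the demo prefills Llava's fixed conversation preamble right after the model finishes loading, so the prompt's KV-cache entries are computed before the user types anything; the returned position is kept in `startPos` for the later image prefill and generate calls. Below is a minimal sketch of that step, assuming (as the diff suggests) that `prefillPrompt(String prompt, long startPos, int bos, int eos)` returns the position at which the cached sequence ends; the wrapper class is illustrative, not part of the commit.

```java
import org.pytorch.executorch.LlamaModule;

// Illustrative sketch, not part of the commit: prefill the static Llava preamble
// once after load and remember where the cached sequence ends.
final class LlavaPromptPrefill {
  static long prefillPreset(LlamaModule module) {
    // The diff passes startPos=0, bos=1, eos=0; presumably this prepends a BOS
    // token and leaves the sequence open so image tokens and the user's text
    // can be appended to the same KV cache later.
    return module.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
  }
}
```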
@@ -241,6 +251,7 @@ protected void onCreate(Bundle savedInstanceState) {
     setupCameraRoll();
     startMemoryUpdate();
     setupShowLogsButton();
+    executor = Executors.newSingleThreadExecutor();
   }
 
   @Override
@@ -546,6 +557,32 @@ private void showMediaPreview(List<Uri> uris) {
       imageViews.get(i).setVisibility(View.VISIBLE);
       imageViews.get(i).setImageURI(mSelectedImageUri.get(i));
     }
+
+    // For LLava, we want to call prefill_image as soon as an image is selected
+    // Llava only support 1 image for now
+    if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
+      List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
+      if (!processedImageList.isEmpty()) {
+        mMessageAdapter.add(
+            new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0));
+        mMessageAdapter.notifyDataSetChanged();
+        Runnable runnable =
+            () -> {
+              Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
+              ETLogging.getInstance().log("Starting runnable prefill image");
+              ETImage img = processedImageList.get(0);
+              ETLogging.getInstance().log("Llava start prefill image");
+              startPos =
+                  mModule.prefillImages(
+                      img.getInts(),
+                      img.getWidth(),
+                      img.getHeight(),
+                      ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
+                      startPos);
+            };
+        executor.execute(runnable);
+      }
+    }
   }
 
   private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
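Here the image is prefilled as soon as it is selected in the media preview: the work is pushed onto the background executor so the UI thread stays free, and the position returned by `prefillImages` is chained onto the one produced by the earlier prompt prefill. A condensed sketch of that pattern follows, assuming the signature shown in the diff (`prefillImages(int[] image, int width, int height, int channels, long startPos)` returning the new position); the wrapper class and field names are illustrative.

```java
import android.os.Process;
import java.util.List;
import java.util.concurrent.Executor;
import org.pytorch.executorch.LlamaModule;

// Illustrative sketch, not part of the commit: prefill the selected image on a
// background executor and advance the shared start position.
final class LlavaImagePrefill {
  private volatile long startPos; // position returned by the earlier prompt prefill

  void prefillFirstImage(Executor executor, LlamaModule module, List<ETImage> images) {
    if (images.isEmpty()) {
      return; // nothing selected yet
    }
    executor.execute(
        () -> {
          // Slight priority boost so prefill tends to finish before the user hits send.
          Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
          ETImage img = images.get(0); // Llava currently handles a single image
          startPos =
              module.prefillImages(
                  img.getInts(), // pixel data as an int array, as in the diff
                  img.getWidth(),
                  img.getHeight(),
                  ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
                  startPos); // continue right after the prefilled prompt
        });
  }
}
```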
@@ -618,24 +655,6 @@ private void onModelRunStopped() {
     mSendButton.setOnClickListener(
         view -> {
           addSelectedImagesToChatThread(mSelectedImageUri);
-          List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
-          processedImageList.forEach(
-              image -> {
-                ETLogging.getInstance()
-                    .log(
-                        "Image preprocessed:"
-                            + " uri = "
-                            + image.getUri().getLastPathSegment()
-                            + ","
-                            + " width = "
-                            + image.getWidth()
-                            + ","
-                            + " height = "
-                            + image.getHeight()
-                            + ","
-                            + " bytes size = "
-                            + image.getBytes().length);
-              });
           String rawPrompt = mEditTextMessage.getText().toString();
           // We store raw prompt into message adapter, because we don't want to show the extra
           // tokens from system prompt
@@ -654,6 +673,8 @@ private void onModelRunStopped() {
               new Runnable() {
                 @Override
                 public void run() {
+                  Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
+                  ETLogging.getInstance().log("starting runnable generate()");
                   runOnUiThread(
                       new Runnable() {
                         @Override
@@ -664,31 +685,12 @@ public void run() {
                   long generateStartTime = System.currentTimeMillis();
                   if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
                       == ModelUtils.VISION_MODEL) {
-                    ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt);
-                    if (!processedImageList.isEmpty()) {
-                      // For now, Llava only support 1 image.
-                      ETImage img = processedImageList.get(0);
-                      mModule.generate(
-                          processedImageList.get(0).getInts(),
-                          img.getWidth(),
-                          img.getHeight(),
-                          ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
-                          rawPrompt,
-                          ModelUtils.VISION_MODEL_SEQ_LEN,
-                          MainActivity.this,
-                          false);
-                    } else {
-                      // no image selected, we pass in empty int array
-                      mModule.generate(
-                          new int[0],
-                          0,
-                          0,
-                          ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
-                          rawPrompt,
-                          ModelUtils.VISION_MODEL_SEQ_LEN,
-                          MainActivity.this,
-                          false);
-                    }
+                    mModule.generateFromPos(
+                        mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt),
+                        ModelUtils.VISION_MODEL_SEQ_LEN,
+                        startPos,
+                        MainActivity.this,
+                        false);
                   } else {
                     String finalPrompt =
                         getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
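Because the preamble and the image are already in the KV cache by the time the user presses send, the vision branch no longer calls `generate` with raw image data; it calls `generateFromPos` with only the formatted user text and the position where prefill stopped. A minimal sketch with the same arguments as the diff follows; the helper class is illustrative, and the callback is whatever implements `LlamaCallback` (the activity in the diff).

```java
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

// Illustrative sketch, not part of the commit: decode from the position reached
// by prefillPrompt/prefillImages instead of re-sending prompt and image.
final class LlavaGenerate {
  static void run(
      LlamaModule module, String formattedUserPrompt, long startPos, LlamaCallback callback) {
    module.generateFromPos(
        formattedUserPrompt, // user turn, formatted by the settings helper in the diff
        ModelUtils.VISION_MODEL_SEQ_LEN, // sequence-length budget, as in the diff
        startPos, // resume after the prefilled preamble + image tokens
        callback, // streams tokens back through onResult()
        false); // echo flag passed as false, matching the diff
  }
}
```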
@@ -712,7 +714,7 @@ public void run() {
                   ETLogging.getInstance().log("Inference completed");
                 }
               };
-          new Thread(runnable).start();
+          executor.execute(runnable);
         });
     mMessageAdapter.notifyDataSetChanged();
   }
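The last hunk swaps the ad-hoc `new Thread(runnable).start()` for the single-thread executor created in `onCreate`. Because that same executor also runs the prompt and image prefill runnables, generation is queued behind any prefill still in flight, so `startPos` is written before `generateFromPos` reads it. A small sketch of that ordering guarantee; the class and method names are illustrative, not from the commit.

```java
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

// Illustrative sketch, not part of the commit: a single-thread executor runs tasks
// one at a time in submission order, so generate() cannot overlap a pending prefill.
final class SerialInferenceQueue {
  private final Executor executor = Executors.newSingleThreadExecutor();

  void submitPrefill(Runnable prefillTask) {
    executor.execute(prefillTask); // e.g. the prefillPrompt / prefillImages runnables
  }

  void submitGenerate(Runnable generateTask) {
    executor.execute(generateTask); // queued behind any prefill submitted earlier
  }
}
```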

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 5 additions & 0 deletions
@@ -66,4 +66,9 @@ public static String getStopToken(ModelType modelType) {
       return "";
     }
   }
+
+  public static String getLlavaPresetPrompt() {
+    return "A chat between a curious human and an artificial intelligence assistant. The assistant"
+        + " gives helpful, detailed, and polite answers to the human's questions. USER: ";
+  }
 }
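The new `getLlavaPresetPrompt()` returns the fixed Llava-1.5 preamble up to and including `USER: `, which is exactly the part that can be cached before the user has typed anything; the image and the user's own text are appended to the same sequence later by `prefillImages` and `generateFromPos`. A small illustrative check of that ordering follows (not part of the commit; the exact user-turn formatting lives in `getFormattedSystemAndUserPrompt`, which this diff does not show).

```java
// Illustrative only: the preamble ends with "USER: ", so the pieces enter the
// KV cache in conversation order:
//   1) preset preamble      -> prefillPrompt(...)   at model-load time
//   2) image tokens         -> prefillImages(...)   when an image is selected
//   3) formatted user text  -> generateFromPos(...) when the user presses send
public final class LlavaPromptOrderCheck {
  public static void main(String[] args) {
    String preamble = PromptFormat.getLlavaPresetPrompt();
    System.out.println(preamble.endsWith("USER: ")); // expected: true
  }
}
```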
