19
19
import android .os .Bundle ;
20
20
import android .os .Handler ;
21
21
import android .os .Looper ;
22
+ import android .os .Process ;
22
23
import android .provider .MediaStore ;
23
24
import android .system .ErrnoException ;
24
25
import android .system .Os ;
44
45
import java .lang .reflect .Type ;
45
46
import java .util .ArrayList ;
46
47
import java .util .List ;
48
+ import java .util .concurrent .Executor ;
49
+ import java .util .concurrent .Executors ;
47
50
import org .pytorch .executorch .LlamaCallback ;
48
51
import org .pytorch .executorch .LlamaModule ;
49
52
@@ -71,15 +74,16 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
71
74
private Handler mMemoryUpdateHandler ;
72
75
private Runnable memoryUpdater ;
73
76
private int promptID = 0 ;
74
-
77
+ private long startPos = 0 ;
75
78
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2 ;
79
+ private Executor executor ;
76
80
77
81
@ Override
78
82
public void onResult (String result ) {
79
83
if (result .equals (PromptFormat .getStopToken (mCurrentSettingsFields .getModelType ()))) {
80
84
return ;
81
85
}
82
- if (result .equals ("\n \n " )) {
86
+ if (result .equals ("\n \n " ) || result . equals ( " \n " ) ) {
83
87
if (!mResultMessage .getText ().isEmpty ()) {
84
88
mResultMessage .appendText (result );
85
89
run ();
@@ -150,6 +154,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
150
154
+ (float ) loadDuration / 1000
151
155
+ " sec."
152
156
+ " You can send text or image for inference" ;
157
+
158
+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5 ) {
159
+ ETLogging .getInstance ().log ("Llava start prefill prompt" );
160
+ startPos = mModule .prefillPrompt (PromptFormat .getLlavaPresetPrompt (), 0 , 1 , 0 );
161
+ ETLogging .getInstance ().log ("Llava completes prefill prompt" );
162
+ }
153
163
}
154
164
155
165
Message modelLoadedMessage = new Message (modelInfo , false , MessageType .SYSTEM , 0 );
@@ -241,6 +251,7 @@ protected void onCreate(Bundle savedInstanceState) {
241
251
setupCameraRoll ();
242
252
startMemoryUpdate ();
243
253
setupShowLogsButton ();
254
+ executor = Executors .newSingleThreadExecutor ();
244
255
}
245
256
246
257
@ Override
@@ -546,6 +557,32 @@ private void showMediaPreview(List<Uri> uris) {
546
557
imageViews .get (i ).setVisibility (View .VISIBLE );
547
558
imageViews .get (i ).setImageURI (mSelectedImageUri .get (i ));
548
559
}
560
+
561
+ // For LLava, we want to call prefill_image as soon as an image is selected
562
+ // Llava only support 1 image for now
563
+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5 ) {
564
+ List <ETImage > processedImageList = getProcessedImagesForModel (mSelectedImageUri );
565
+ if (!processedImageList .isEmpty ()) {
566
+ mMessageAdapter .add (
567
+ new Message ("Llava - Starting image Prefill." , false , MessageType .SYSTEM , 0 ));
568
+ mMessageAdapter .notifyDataSetChanged ();
569
+ Runnable runnable =
570
+ () -> {
571
+ Process .setThreadPriority (Process .THREAD_PRIORITY_MORE_FAVORABLE );
572
+ ETLogging .getInstance ().log ("Starting runnable prefill image" );
573
+ ETImage img = processedImageList .get (0 );
574
+ ETLogging .getInstance ().log ("Llava start prefill image" );
575
+ startPos =
576
+ mModule .prefillImages (
577
+ img .getInts (),
578
+ img .getWidth (),
579
+ img .getHeight (),
580
+ ModelUtils .VISION_MODEL_IMAGE_CHANNELS ,
581
+ startPos );
582
+ };
583
+ executor .execute (runnable );
584
+ }
585
+ }
549
586
}
550
587
551
588
private void addSelectedImagesToChatThread (List <Uri > selectedImageUri ) {
@@ -618,24 +655,6 @@ private void onModelRunStopped() {
618
655
mSendButton .setOnClickListener (
619
656
view -> {
620
657
addSelectedImagesToChatThread (mSelectedImageUri );
621
- List <ETImage > processedImageList = getProcessedImagesForModel (mSelectedImageUri );
622
- processedImageList .forEach (
623
- image -> {
624
- ETLogging .getInstance ()
625
- .log (
626
- "Image preprocessed:"
627
- + " uri = "
628
- + image .getUri ().getLastPathSegment ()
629
- + ","
630
- + " width = "
631
- + image .getWidth ()
632
- + ","
633
- + " height = "
634
- + image .getHeight ()
635
- + ","
636
- + " bytes size = "
637
- + image .getBytes ().length );
638
- });
639
658
String rawPrompt = mEditTextMessage .getText ().toString ();
640
659
// We store raw prompt into message adapter, because we don't want to show the extra
641
660
// tokens from system prompt
@@ -654,6 +673,8 @@ private void onModelRunStopped() {
654
673
new Runnable () {
655
674
@ Override
656
675
public void run () {
676
+ Process .setThreadPriority (Process .THREAD_PRIORITY_MORE_FAVORABLE );
677
+ ETLogging .getInstance ().log ("starting runnable generate()" );
657
678
runOnUiThread (
658
679
new Runnable () {
659
680
@ Override
@@ -664,31 +685,12 @@ public void run() {
664
685
long generateStartTime = System .currentTimeMillis ();
665
686
if (ModelUtils .getModelCategory (mCurrentSettingsFields .getModelType ())
666
687
== ModelUtils .VISION_MODEL ) {
667
- ETLogging .getInstance ().log ("Running inference.. prompt=" + rawPrompt );
668
- if (!processedImageList .isEmpty ()) {
669
- // For now, Llava only support 1 image.
670
- ETImage img = processedImageList .get (0 );
671
- mModule .generate (
672
- processedImageList .get (0 ).getInts (),
673
- img .getWidth (),
674
- img .getHeight (),
675
- ModelUtils .VISION_MODEL_IMAGE_CHANNELS ,
676
- rawPrompt ,
677
- ModelUtils .VISION_MODEL_SEQ_LEN ,
678
- MainActivity .this ,
679
- false );
680
- } else {
681
- // no image selected, we pass in empty int array
682
- mModule .generate (
683
- new int [0 ],
684
- 0 ,
685
- 0 ,
686
- ModelUtils .VISION_MODEL_IMAGE_CHANNELS ,
687
- rawPrompt ,
688
- ModelUtils .VISION_MODEL_SEQ_LEN ,
689
- MainActivity .this ,
690
- false );
691
- }
688
+ mModule .generateFromPos (
689
+ mCurrentSettingsFields .getFormattedSystemAndUserPrompt (rawPrompt ),
690
+ ModelUtils .VISION_MODEL_SEQ_LEN ,
691
+ startPos ,
692
+ MainActivity .this ,
693
+ false );
692
694
} else {
693
695
String finalPrompt =
694
696
getTotalFormattedPrompt (getConversationHistory (), rawPrompt );
@@ -712,7 +714,7 @@ public void run() {
712
714
ETLogging .getInstance ().log ("Inference completed" );
713
715
}
714
716
};
715
- new Thread ( runnable ). start ( );
717
+ executor . execute ( runnable );
716
718
});
717
719
mMessageAdapter .notifyDataSetChanged ();
718
720
}
0 commit comments