
Commit d102ef1

Riandy authored and facebook-github-bot committed
Unified Android aar support for llava and llama models
Summary:
- Previously, we needed two separate aars for the vision and text models. Since the ET core runner side now builds a combined aar, this change updates the Android app side to support that behavior.
- Introduces a ModelUtils class so we can get the correct model category to pass to generate().
- seq_len is now an exposed parameter, defaulting to 128. For llava models, 128 is not enough, so we change it to 768 when calling generate().
- Minor bug fix in the ETImage logic.

Reviewed By: cmodi-meta, kirklandsign

Differential Revision: D61406255
1 parent f326ee1 commit d102ef1
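For context, here is a minimal sketch of the unified call pattern this commit introduces on the app side. It only mirrors names and signatures that appear in the diffs below (ModelUtils, LlamaModule, the generate() overloads); the callback argument stands in for whatever listener the activity implements (MainActivity.this in the actual code).

    // Pick the model category once, then drive one LlamaModule for both model families.
    int category = ModelUtils.getModelCategory(modelType);
    LlamaModule module = new LlamaModule(category, modelPath, tokenizerPath, temperature);
    module.load();

    if (category == ModelUtils.VISION_MODEL) {
      // Vision path: pass the image data as ints, its shape, and the longer llava seq_len (768).
      module.generate(
          image.getInts(), image.getWidth(), image.getHeight(),
          ModelUtils.VISION_MODEL_IMAGE_CHANNELS, prompt,
          ModelUtils.VISION_MODEL_SEQ_LEN, callback);
    } else {
      // Text path: just the prompt and the text seq_len (256).
      module.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, callback);
    }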

File tree: 4 files changed (+77 / −7 lines)


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java

Lines changed: 13 additions & 3 deletions
@@ -46,6 +46,16 @@ public byte[] getBytes() {
     return bytes;
   }
 
+  public int[] getInts() {
+    // We need to convert the byte array to an int array because
+    // the runner expects an int array as input.
+    int[] intArray = new int[bytes.length];
+    for (int i = 0; i < bytes.length; i++) {
+      intArray[i] = bytes[i] & 0xFF;
+    }
+    return intArray;
+  }
+
   private byte[] getBytesFromImageURI(Uri uri) {
     try {
       int RESIZED_IMAGE_WIDTH = 336;
@@ -72,9 +82,9 @@ private byte[] getBytesFromImageURI(Uri uri) {
         int blue = Color.blue(color);
 
         // Store the RGB values in the byte array
-        rgbValues[(y * width + x) * 3] = (byte) red;
-        rgbValues[(y * width + x) * 3 + 1] = (byte) green;
-        rgbValues[(y * width + x) * 3 + 2] = (byte) blue;
+        rgbValues[y * width + x] = (byte) red;
+        rgbValues[(y * width + x) + height * width] = (byte) green;
+        rgbValues[(y * width + x) + 2 * height * width] = (byte) blue;
       }
     }
     return rgbValues;
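The indexing change above is the "minor bug fix" from the summary: it swaps the RGB buffer from interleaved HWC order (R, G, B packed per pixel) to planar CHW order (the full red plane, then green, then blue), which is evidently the layout the combined runner expects. A small illustrative sketch of the difference, with hypothetical helper names:

    // Old layout: channels interleaved per pixel.
    static byte redInterleaved(byte[] buf, int x, int y, int width) {
      return buf[(y * width + x) * 3];
    }

    // New layout: one full plane per channel, red first.
    static byte redPlanar(byte[] buf, int x, int y, int width) {
      return buf[y * width + x];
    }

    // The green plane starts after the width * height red values.
    static byte greenPlanar(byte[] buf, int x, int y, int width, int height) {
      return buf[(y * width + x) + height * width];
    }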

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 34 additions & 4 deletions
@@ -102,7 +102,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
           mMessageAdapter.notifyDataSetChanged();
         });
     long runStartTime = System.currentTimeMillis();
-    mModule = new LlamaModule(modelPath, tokenizerPath, temperature);
+    mModule =
+        new LlamaModule(
+            ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()),
+            modelPath,
+            tokenizerPath,
+            temperature);
     int loadResult = mModule.load();
     long loadDuration = System.currentTimeMillis() - runStartTime;
     String modelLoadError = "";
@@ -552,8 +557,6 @@ private void onModelRunStopped() {
     mSendButton.setOnClickListener(
         view -> {
           addSelectedImagesToChatThread(mSelectedImageUri);
-          // TODO: When ET supports multimodal, this is where we will add the images as part of the
-          // prompt.
           List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
           processedImageList.forEach(
               image -> {
@@ -599,7 +602,34 @@ public void run() {
                 });
             ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
             long generateStartTime = System.currentTimeMillis();
-            mModule.generate(prompt, MainActivity.this);
+            if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
+                == ModelUtils.VISION_MODEL) {
+              if (!processedImageList.isEmpty()) {
+                // For now, Llava only supports 1 image.
+                ETImage img = processedImageList.get(0);
+                mModule.generate(
+                    img.getInts(),
+                    img.getWidth(),
+                    img.getHeight(),
+                    ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
+                    prompt,
+                    ModelUtils.VISION_MODEL_SEQ_LEN,
+                    MainActivity.this);
+              } else {
+                // No image selected; pass in an empty int array.
+                mModule.generate(
+                    new int[0],
+                    0,
+                    0,
+                    ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
+                    prompt,
+                    ModelUtils.VISION_MODEL_SEQ_LEN,
+                    MainActivity.this);
+              }
+            } else {
+              mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);
+            }
+
             long generateDuration = System.currentTimeMillis() - generateStartTime;
             mResultMessage.setTotalGenerationTime(generateDuration);
             runOnUiThread(

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+public class ModelUtils {
+  static final int TEXT_MODEL = 1;
+  static final int VISION_MODEL = 2;
+  static final int VISION_MODEL_IMAGE_CHANNELS = 3;
+  static final int VISION_MODEL_SEQ_LEN = 768;
+  static final int TEXT_MODEL_SEQ_LEN = 256;
+
+  public static int getModelCategory(ModelType modelType) {
+    switch (modelType) {
+      case LLAVA_1_5:
+        return VISION_MODEL;
+      case LLAMA_3:
+      case LLAMA_3_1:
+      default:
+        return TEXT_MODEL;
+    }
+  }
+}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ public static String getSystemPromptTemplate(ModelType modelType) {
             + SYSTEM_PLACEHOLDER
             + "<|eot_id|>";
       case LLAVA_1_5:
+        return "USER: ";
       default:
         return SYSTEM_PLACEHOLDER;
     }
@@ -35,6 +36,7 @@ public static String getUserPromptTemplate(ModelType modelType) {
             + "<|eot_id|>\n"
             + "<|start_header_id|>assistant<|end_header_id|>";
       case LLAVA_1_5:
+        return USER_PLACEHOLDER + " ASSISTANT:";
       default:
         return USER_PLACEHOLDER;
     }
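Taken together, these two templates produce the LLaVA-1.5 chat format: the system template contributes the "USER: " prefix and the user template appends the turn delimiter. For example, the user text "What is in this image?" would render roughly as:

    USER: What is in this image? ASSISTANT: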
