Skip to content

Commit b52d4b6

Browse files
authored
Enable Llama3 Multi-turn conversation
Differential Revision: D61134262 Pull Request resolved: #4721
1 parent 647bfd4 commit b52d4b6

File tree

4 files changed

+118
-10
lines changed

4 files changed

+118
-10
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7070
private SettingsFields mCurrentSettingsFields;
7171
private Handler mMemoryUpdateHandler;
7272
private Runnable memoryUpdater;
73+
private int promptID = 0;
74+
75+
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
7376

7477
@Override
7578
public void onResult(String result) {
@@ -195,6 +198,11 @@ private void populateExistingMessages(String existingMsgJSON) {
195198
mMessageAdapter.notifyDataSetChanged();
196199
}
197200

201+
private int setPromptID() {
202+
203+
return mMessageAdapter.getMaxPromptID() + 1;
204+
}
205+
198206
@Override
199207
protected void onCreate(Bundle savedInstanceState) {
200208
super.onCreate(savedInstanceState);
@@ -216,6 +224,7 @@ protected void onCreate(Bundle savedInstanceState) {
216224
String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
217225
if (!existingMsgJSON.isEmpty()) {
218226
populateExistingMessages(existingMsgJSON);
227+
promptID = setPromptID();
219228
}
220229
mSettingsButton = requireViewById(R.id.settings);
221230
mSettingsButton.setOnClickListener(
@@ -552,6 +561,48 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
552561
mMessageAdapter.notifyDataSetChanged();
553562
}
554563

564+
private String getConversationHistory() {
565+
String conversationHistory = "";
566+
567+
ArrayList<Message> conversations =
568+
mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
569+
if (conversations.isEmpty()) {
570+
return conversationHistory;
571+
}
572+
573+
int prevPromptID = conversations.get(0).getPromptID();
574+
String conversationFormat =
575+
PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
576+
String format = conversationFormat;
577+
for (int i = 0; i < conversations.size(); i++) {
578+
Message conversation = conversations.get(i);
579+
int currentPromptID = conversation.getPromptID();
580+
if (currentPromptID != prevPromptID) {
581+
conversationHistory = conversationHistory + format;
582+
format = conversationFormat;
583+
prevPromptID = currentPromptID;
584+
}
585+
if (conversation.getIsSent()) {
586+
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
587+
} else {
588+
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
589+
}
590+
}
591+
conversationHistory = conversationHistory + format;
592+
593+
return conversationHistory;
594+
}
595+
596+
private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
597+
if (conversationHistory.isEmpty()) {
598+
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
599+
}
600+
601+
return mCurrentSettingsFields.getFormattedSystemPrompt()
602+
+ conversationHistory
603+
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
604+
}
605+
555606
private void onModelRunStarted() {
556607
mSendButton.setClickable(false);
557608
mSendButton.setImageResource(R.drawable.baseline_stop_24);
@@ -586,19 +637,19 @@ private void onModelRunStopped() {
586637
+ image.getBytes().length);
587638
});
588639
String rawPrompt = mEditTextMessage.getText().toString();
589-
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
590640
// We store raw prompt into message adapter, because we don't want to show the extra
591641
// tokens from system prompt
592-
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
642+
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
593643
mMessageAdapter.notifyDataSetChanged();
594644
mEditTextMessage.setText("");
595-
mResultMessage = new Message("", false, MessageType.TEXT, 0);
645+
mResultMessage = new Message("", false, MessageType.TEXT, promptID);
596646
mMessageAdapter.add(mResultMessage);
597647
// Scroll to bottom of the list
598648
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
599649
// After images are added to prompt and chat thread, we clear the imageURI list
600650
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
601651
mSelectedImageUri = null;
652+
promptID++;
602653
Runnable runnable =
603654
new Runnable() {
604655
@Override
@@ -610,10 +661,10 @@ public void run() {
610661
onModelRunStarted();
611662
}
612663
});
613-
ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
614664
long generateStartTime = System.currentTimeMillis();
615665
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
616666
== ModelUtils.VISION_MODEL) {
667+
ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt);
617668
if (!processedImageList.isEmpty()) {
618669
// For now, Llava only support 1 image.
619670
ETImage img = processedImageList.get(0);
@@ -622,7 +673,7 @@ public void run() {
622673
img.getWidth(),
623674
img.getHeight(),
624675
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
625-
prompt,
676+
rawPrompt,
626677
ModelUtils.VISION_MODEL_SEQ_LEN,
627678
false,
628679
MainActivity.this);
@@ -633,14 +684,20 @@ public void run() {
633684
0,
634685
0,
635686
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
636-
prompt,
687+
rawPrompt,
637688
ModelUtils.VISION_MODEL_SEQ_LEN,
638689
false,
639690
MainActivity.this);
640691
}
641692
} else {
693+
String finalPrompt =
694+
getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
695+
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
642696
mModule.generate(
643-
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
697+
finalPrompt,
698+
(int) (finalPrompt.length() * 0.75) + 64,
699+
false,
700+
MainActivity.this);
644701
}
645702

646703
long generateDuration = System.currentTimeMillis() - generateStartTime;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import android.widget.ImageView;
1717
import android.widget.TextView;
1818
import java.util.ArrayList;
19+
import java.util.Collections;
1920

2021
public class MessageAdapter extends ArrayAdapter<Message> {
2122

@@ -90,4 +91,41 @@ public void clear() {
9091
public ArrayList<Message> getSavedMessages() {
9192
return savedMessages;
9293
}
94+
95+
public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
96+
ArrayList<Message> recentMessages = new ArrayList<Message>();
97+
int lastIndex = savedMessages.size() - 1;
98+
Message messageToAdd = savedMessages.get(lastIndex);
99+
int oldPromptID = messageToAdd.getPromptID();
100+
101+
for (int i = 0; i < savedMessages.size(); i++) {
102+
messageToAdd = savedMessages.get(lastIndex - i);
103+
if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
104+
if (messageToAdd.getPromptID() != oldPromptID) {
105+
numOfLatestPromptMessages--;
106+
oldPromptID = messageToAdd.getPromptID();
107+
}
108+
if (numOfLatestPromptMessages > 0) {
109+
if (messageToAdd.getMessageType() == MessageType.TEXT) {
110+
recentMessages.add(messageToAdd);
111+
}
112+
} else {
113+
break;
114+
}
115+
}
116+
}
117+
118+
// To place the order in [input1, output1, input2, output2...]
119+
Collections.reverse(recentMessages);
120+
return recentMessages;
121+
}
122+
123+
public int getMaxPromptID() {
124+
int maxPromptID = -1;
125+
for (Message msg : savedMessages) {
126+
127+
maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
128+
}
129+
return maxPromptID;
130+
}
93131
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class PromptFormat {
1212

1313
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
1414
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
15+
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
1516

1617
public static String getSystemPromptTemplate(ModelType modelType) {
1718
switch (modelType) {
@@ -33,8 +34,20 @@ public static String getUserPromptTemplate(ModelType modelType) {
3334
case LLAMA_3_1:
3435
return "<|start_header_id|>user<|end_header_id|>\n"
3536
+ USER_PLACEHOLDER
36-
+ "<|eot_id|>\n"
37+
+ "<|eot_id|>"
3738
+ "<|start_header_id|>assistant<|end_header_id|>";
39+
40+
case LLAVA_1_5:
41+
default:
42+
return USER_PLACEHOLDER;
43+
}
44+
}
45+
46+
public static String getConversationFormat(ModelType modelType) {
47+
switch (modelType) {
48+
case LLAMA_3:
49+
case LLAMA_3_1:
50+
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
3851
case LLAVA_1_5:
3952
return USER_PLACEHOLDER + " ASSISTANT:";
4053
default:

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) {
3838
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
3939
}
4040

41-
private String getFormattedSystemPrompt() {
41+
public String getFormattedSystemPrompt() {
4242
return PromptFormat.getSystemPromptTemplate(modelType)
4343
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
4444
}
4545

46-
private String getFormattedUserPrompt(String prompt) {
46+
public String getFormattedUserPrompt(String prompt) {
4747
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
4848
}
4949

0 commit comments

Comments
 (0)