Skip to content

Commit 83fc25d

Browse files
cmodi-meta authored and facebook-github-bot committed
Enable Llama3 Multi-turn conversation (#4721)
Summary: Pull Request resolved: #4721 To provide more conversational output, we now include previous prompts/responses as part of the input to `generate()`, following the [Llama3 prompt formatting](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/). 1. Currently we look back at the past 2 conversations (i.e. `CONVERSATION_HISTORY_MESSAGE_LOOKBACK`). 2. This is for text-only prompts and responses. 3. Conversation history is still supported after the user closes and re-opens the app. 4. As part of this, we needed to separate out how `prompt` is placed in the `generate()` call, since the system prompt is always first, followed by the conversation history (if present) and then the current prompt. Multi-turn format (with example from [Llama 3 Model Card](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3)): ``` <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|> What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|> What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Paris, the City of Light, offers a romantic getaway with must-see attractions like the Eiffel Tower and Louvre Museum, romantic experiences like river cruises and charming neighborhoods, and delicious food and drink options, with helpful tips for making the most of your trip.<|eot_id|><|start_header_id|>user<|end_header_id|> Give me a detailed list of the attractions I should visit, and time it takes in each one, to plan my trip accordingly.<|eot_id|><|start_header_id|>assistant<|end_header_id|> ``` Reviewed By: Riandy Differential Revision: D61134262
1 parent 99fbca3 commit 83fc25d

File tree

4 files changed

+118
-10
lines changed

4 files changed

+118
-10
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7070
private SettingsFields mCurrentSettingsFields;
7171
private Handler mMemoryUpdateHandler;
7272
private Runnable memoryUpdater;
73+
private int promptID = 0;
74+
75+
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
7376

7477
@Override
7578
public void onResult(String result) {
@@ -195,6 +198,11 @@ private void populateExistingMessages(String existingMsgJSON) {
195198
mMessageAdapter.notifyDataSetChanged();
196199
}
197200

201+
private int setPromptID() {
202+
203+
return mMessageAdapter.getMaxPromptID() + 1;
204+
}
205+
198206
@Override
199207
protected void onCreate(Bundle savedInstanceState) {
200208
super.onCreate(savedInstanceState);
@@ -216,6 +224,7 @@ protected void onCreate(Bundle savedInstanceState) {
216224
String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
217225
if (!existingMsgJSON.isEmpty()) {
218226
populateExistingMessages(existingMsgJSON);
227+
promptID = setPromptID();
219228
}
220229
mSettingsButton = requireViewById(R.id.settings);
221230
mSettingsButton.setOnClickListener(
@@ -552,6 +561,48 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
552561
mMessageAdapter.notifyDataSetChanged();
553562
}
554563

564+
private String getConversationHistory() {
565+
String conversationHistory = "";
566+
567+
ArrayList<Message> conversations =
568+
mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
569+
if (conversations.isEmpty()) {
570+
return conversationHistory;
571+
}
572+
573+
int prevPromptID = conversations.get(0).getPromptID();
574+
String conversationFormat =
575+
PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
576+
String format = conversationFormat;
577+
for (int i = 0; i < conversations.size(); i++) {
578+
Message conversation = conversations.get(i);
579+
int currentPromptID = conversation.getPromptID();
580+
if (currentPromptID != prevPromptID) {
581+
conversationHistory = conversationHistory + format;
582+
format = conversationFormat;
583+
prevPromptID = currentPromptID;
584+
}
585+
if (conversation.getIsSent()) {
586+
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
587+
} else {
588+
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
589+
}
590+
}
591+
conversationHistory = conversationHistory + format;
592+
593+
return conversationHistory;
594+
}
595+
596+
private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
597+
if (conversationHistory.isEmpty()) {
598+
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
599+
}
600+
601+
return mCurrentSettingsFields.getFormattedSystemPrompt()
602+
+ conversationHistory
603+
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
604+
}
605+
555606
private void onModelRunStarted() {
556607
mSendButton.setClickable(false);
557608
mSendButton.setImageResource(R.drawable.baseline_stop_24);
@@ -586,19 +637,19 @@ private void onModelRunStopped() {
586637
+ image.getBytes().length);
587638
});
588639
String rawPrompt = mEditTextMessage.getText().toString();
589-
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
590640
// We store raw prompt into message adapter, because we don't want to show the extra
591641
// tokens from system prompt
592-
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
642+
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
593643
mMessageAdapter.notifyDataSetChanged();
594644
mEditTextMessage.setText("");
595-
mResultMessage = new Message("", false, MessageType.TEXT, 0);
645+
mResultMessage = new Message("", false, MessageType.TEXT, promptID);
596646
mMessageAdapter.add(mResultMessage);
597647
// Scroll to bottom of the list
598648
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
599649
// After images are added to prompt and chat thread, we clear the imageURI list
600650
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
601651
mSelectedImageUri = null;
652+
promptID++;
602653
Runnable runnable =
603654
new Runnable() {
604655
@Override
@@ -610,10 +661,10 @@ public void run() {
610661
onModelRunStarted();
611662
}
612663
});
613-
ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
614664
long generateStartTime = System.currentTimeMillis();
615665
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
616666
== ModelUtils.VISION_MODEL) {
667+
ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt);
617668
if (!processedImageList.isEmpty()) {
618669
// For now, Llava only support 1 image.
619670
ETImage img = processedImageList.get(0);
@@ -622,7 +673,7 @@ public void run() {
622673
img.getWidth(),
623674
img.getHeight(),
624675
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
625-
prompt,
676+
rawPrompt,
626677
ModelUtils.VISION_MODEL_SEQ_LEN,
627678
false,
628679
MainActivity.this);
@@ -633,14 +684,20 @@ public void run() {
633684
0,
634685
0,
635686
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
636-
prompt,
687+
rawPrompt,
637688
ModelUtils.VISION_MODEL_SEQ_LEN,
638689
false,
639690
MainActivity.this);
640691
}
641692
} else {
693+
String finalPrompt =
694+
getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
695+
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
642696
mModule.generate(
643-
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
697+
finalPrompt,
698+
(int) (finalPrompt.length() * 0.75) + 64,
699+
false,
700+
MainActivity.this);
644701
}
645702

646703
long generateDuration = System.currentTimeMillis() - generateStartTime;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import android.widget.ImageView;
1717
import android.widget.TextView;
1818
import java.util.ArrayList;
19+
import java.util.Collections;
1920

2021
public class MessageAdapter extends ArrayAdapter<Message> {
2122

@@ -90,4 +91,41 @@ public void clear() {
9091
public ArrayList<Message> getSavedMessages() {
9192
return savedMessages;
9293
}
94+
95+
public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
96+
ArrayList<Message> recentMessages = new ArrayList<Message>();
97+
int lastIndex = savedMessages.size() - 1;
98+
Message messageToAdd = savedMessages.get(lastIndex);
99+
int oldPromptID = messageToAdd.getPromptID();
100+
101+
for (int i = 0; i < savedMessages.size(); i++) {
102+
messageToAdd = savedMessages.get(lastIndex - i);
103+
if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
104+
if (messageToAdd.getPromptID() != oldPromptID) {
105+
numOfLatestPromptMessages--;
106+
oldPromptID = messageToAdd.getPromptID();
107+
}
108+
if (numOfLatestPromptMessages > 0) {
109+
if (messageToAdd.getMessageType() == MessageType.TEXT) {
110+
recentMessages.add(messageToAdd);
111+
}
112+
} else {
113+
break;
114+
}
115+
}
116+
}
117+
118+
// To place the order in [input1, output1, input2, output2...]
119+
Collections.reverse(recentMessages);
120+
return recentMessages;
121+
}
122+
123+
public int getMaxPromptID() {
124+
int maxPromptID = -1;
125+
for (Message msg : savedMessages) {
126+
127+
maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
128+
}
129+
return maxPromptID;
130+
}
93131
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class PromptFormat {
1212

1313
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
1414
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
15+
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
1516

1617
public static String getSystemPromptTemplate(ModelType modelType) {
1718
switch (modelType) {
@@ -33,8 +34,20 @@ public static String getUserPromptTemplate(ModelType modelType) {
3334
case LLAMA_3_1:
3435
return "<|start_header_id|>user<|end_header_id|>\n"
3536
+ USER_PLACEHOLDER
36-
+ "<|eot_id|>\n"
37+
+ "<|eot_id|>"
3738
+ "<|start_header_id|>assistant<|end_header_id|>";
39+
40+
case LLAVA_1_5:
41+
default:
42+
return USER_PLACEHOLDER;
43+
}
44+
}
45+
46+
public static String getConversationFormat(ModelType modelType) {
47+
switch (modelType) {
48+
case LLAMA_3:
49+
case LLAMA_3_1:
50+
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
3851
case LLAVA_1_5:
3952
return USER_PLACEHOLDER + " ASSISTANT:";
4053
default:

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) {
3838
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
3939
}
4040

41-
private String getFormattedSystemPrompt() {
41+
public String getFormattedSystemPrompt() {
4242
return PromptFormat.getSystemPromptTemplate(modelType)
4343
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
4444
}
4545

46-
private String getFormattedUserPrompt(String prompt) {
46+
public String getFormattedUserPrompt(String prompt) {
4747
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
4848
}
4949

0 commit comments

Comments
 (0)