Skip to content

Commit 26cbd40

Browse files
cmodi-meta authored and
facebook-github-bot committed
Enable Llama3 Multi-turn conversation
Summary: To produce more conversational output, we now include previous prompts/responses as part of the input to `generate()`, following the [Llama 3 prompt format](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/).
1. Currently we look back at the past 2 conversations (i.e. `CONVERSATION_HISTORY_MESSAGE_LOOKBACK`).
2. This applies to text-only prompts and responses.
3. This is supported when the user closes and re-opens the app.
4. As part of this, we needed to separate out how the `prompt` is assembled for the `generate()` function, since the system prompt always comes first, followed by the conversation history (if present) and then the current prompt.
Multi-turn format (with example from the [Llama 3 Model Card](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3)):
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|> What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|> What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Paris, the City of Light, offers a romantic getaway with must-see attractions like the Eiffel Tower and Louvre Museum, romantic experiences like river cruises and charming neighborhoods, and delicious food and drink options, with helpful tips for making the most of your trip.<|eot_id|><|start_header_id|>user<|end_header_id|> Give me a detailed list of the attractions I should visit, and time it takes in each one, to plan my trip accordingly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
Reviewed By: Riandy
Differential Revision: D61134262
1 parent 84100d1 commit 26cbd40

File tree

5 files changed

+113
-9
lines changed

5 files changed

+113
-9
lines changed

examples/demo-apps/android/LlamaDemo/app/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ dependencies {
5757
implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12")
5858
implementation("com.facebook.fbjni:fbjni:0.5.1")
5959
implementation("com.google.code.gson:gson:2.8.6")
60-
implementation(files("libs/executorch-llama.aar"))
60+
implementation(files("libs/modified-executorch-llama.aar"))
6161
implementation("com.google.android.material:material:1.12.0")
6262
implementation("androidx.activity:activity:1.9.0")
6363
testImplementation("junit:junit:4.13.2")

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7070
private SettingsFields mCurrentSettingsFields;
7171
private Handler mMemoryUpdateHandler;
7272
private Runnable memoryUpdater;
73+
private int promptID = 0;
74+
75+
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
7376

7477
@Override
7578
public void onResult(String result) {
@@ -180,6 +183,11 @@ private void populateExistingMessages(String existingMsgJSON) {
180183
mMessageAdapter.notifyDataSetChanged();
181184
}
182185

186+
private int setPromptID() {
187+
188+
return mMessageAdapter.getMaxPromptID() + 1;
189+
}
190+
183191
@Override
184192
protected void onCreate(Bundle savedInstanceState) {
185193
super.onCreate(savedInstanceState);
@@ -201,6 +209,7 @@ protected void onCreate(Bundle savedInstanceState) {
201209
String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
202210
if (!existingMsgJSON.isEmpty()) {
203211
populateExistingMessages(existingMsgJSON);
212+
promptID = setPromptID();
204213
}
205214
mSettingsButton = requireViewById(R.id.settings);
206215
mSettingsButton.setOnClickListener(
@@ -537,6 +546,48 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
537546
mMessageAdapter.notifyDataSetChanged();
538547
}
539548

549+
private String getConversationHistory() {
550+
String conversationHistory = "";
551+
552+
ArrayList<Message> conversations =
553+
mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
554+
if (conversations.isEmpty()) {
555+
return conversationHistory;
556+
}
557+
558+
int prevPromptID = conversations.get(0).getPromptID();
559+
String conversationFormat =
560+
PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
561+
String format = conversationFormat;
562+
for (int i = 0; i < conversations.size(); i++) {
563+
Message conversation = conversations.get(i);
564+
int currentPromptID = conversation.getPromptID();
565+
if (currentPromptID != prevPromptID) {
566+
conversationHistory = conversationHistory + format;
567+
format = conversationFormat;
568+
prevPromptID = currentPromptID;
569+
}
570+
if (conversation.getIsSent()) {
571+
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
572+
} else {
573+
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
574+
}
575+
}
576+
conversationHistory = conversationHistory + format;
577+
578+
return conversationHistory;
579+
}
580+
581+
private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
582+
if (conversationHistory.isEmpty()) {
583+
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
584+
}
585+
586+
return mCurrentSettingsFields.getFormattedSystemPrompt()
587+
+ conversationHistory
588+
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
589+
}
590+
540591
private void onModelRunStarted() {
541592
mSendButton.setClickable(false);
542593
mSendButton.setImageResource(R.drawable.baseline_stop_24);
@@ -572,20 +623,21 @@ private void onModelRunStopped() {
572623
+ " bytes size = "
573624
+ image.getBytes().length);
574625
});
626+
String conversationHistory = getConversationHistory();
575627
String rawPrompt = mEditTextMessage.getText().toString();
576-
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
577628
// We store raw prompt into message adapter, because we don't want to show the extra
578629
// tokens from system prompt
579-
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
630+
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
580631
mMessageAdapter.notifyDataSetChanged();
581632
mEditTextMessage.setText("");
582-
mResultMessage = new Message("", false, MessageType.TEXT, 0);
633+
mResultMessage = new Message("", false, MessageType.TEXT, promptID);
583634
mMessageAdapter.add(mResultMessage);
584635
// Scroll to bottom of the list
585636
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
586637
// After images are added to prompt and chat thread, we clear the imageURI list
587638
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
588639
mSelectedImageUri = null;
640+
promptID++;
589641
Runnable runnable =
590642
new Runnable() {
591643
@Override
@@ -597,9 +649,10 @@ public void run() {
597649
onModelRunStarted();
598650
}
599651
});
600-
ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
652+
String finalPrompt = getTotalFormattedPrompt(conversationHistory, rawPrompt);
653+
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
601654
long generateStartTime = System.currentTimeMillis();
602-
mModule.generate(prompt, MainActivity.this);
655+
mModule.generate(finalPrompt, MainActivity.this);
603656
long generateDuration = System.currentTimeMillis() - generateStartTime;
604657
mResultMessage.setTotalGenerationTime(generateDuration);
605658
runOnUiThread(

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import android.widget.ImageView;
1717
import android.widget.TextView;
1818
import java.util.ArrayList;
19+
import java.util.Collections;
1920

2021
public class MessageAdapter extends ArrayAdapter<Message> {
2122

@@ -90,4 +91,41 @@ public void clear() {
9091
// Returns the adapter's savedMessages list itself (no copy is made here).
public ArrayList<Message> getSavedMessages() {
9192
return savedMessages;
9293
}
94+
95+
public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
96+
ArrayList<Message> recentMessages = new ArrayList<Message>();
97+
int lastIndex = savedMessages.size() - 1;
98+
Message messageToAdd = savedMessages.get(lastIndex);
99+
int oldPromptID = messageToAdd.getPromptID();
100+
101+
for (int i = 0; i < savedMessages.size(); i++) {
102+
messageToAdd = savedMessages.get(lastIndex - i);
103+
if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
104+
if (messageToAdd.getPromptID() != oldPromptID) {
105+
numOfLatestPromptMessages--;
106+
oldPromptID = messageToAdd.getPromptID();
107+
}
108+
if (numOfLatestPromptMessages > 0) {
109+
if (messageToAdd.getMessageType() == MessageType.TEXT) {
110+
recentMessages.add(messageToAdd);
111+
}
112+
} else {
113+
break;
114+
}
115+
}
116+
}
117+
118+
// To place the order in [input1, output1, input2, output2...]
119+
Collections.reverse(recentMessages);
120+
return recentMessages;
121+
}
122+
123+
public int getMaxPromptID() {
124+
int maxPromptID = -1;
125+
for (Message msg : savedMessages) {
126+
127+
maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
128+
}
129+
return maxPromptID;
130+
}
93131
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class PromptFormat {
1212

1313
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
1414
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
15+
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
1516

1617
public static String getSystemPromptTemplate(ModelType modelType) {
1718
switch (modelType) {
@@ -32,8 +33,20 @@ public static String getUserPromptTemplate(ModelType modelType) {
3233
case LLAMA_3_1:
3334
return "<|start_header_id|>user<|end_header_id|>\n"
3435
+ USER_PLACEHOLDER
35-
+ "<|eot_id|>\n"
36+
+ "<|eot_id|>"
3637
+ "<|start_header_id|>assistant<|end_header_id|>";
38+
39+
case LLAVA_1_5:
40+
default:
41+
return USER_PLACEHOLDER;
42+
}
43+
}
44+
45+
public static String getConversationFormat(ModelType modelType) {
46+
switch (modelType) {
47+
case LLAMA_3:
48+
case LLAMA_3_1:
49+
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
3750
case LLAVA_1_5:
3851
default:
3952
return USER_PLACEHOLDER;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) {
3838
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
3939
}
4040

41-
private String getFormattedSystemPrompt() {
41+
public String getFormattedSystemPrompt() {
4242
return PromptFormat.getSystemPromptTemplate(modelType)
4343
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
4444
}
4545

46-
private String getFormattedUserPrompt(String prompt) {
46+
public String getFormattedUserPrompt(String prompt) {
4747
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
4848
}
4949

0 commit comments

Comments
 (0)