Skip to content

Commit 05cd087

Browse files
cmodi-metafacebook-github-bot
authored andcommitted
Enable Llama3 Multi-turn conversation (#4721)
Summary: Pull Request resolved: #4721 To provide more conversational type output, we now include previous prompt/responses as part of our input into the `generate()` by the [Llama3 prompt formatting ](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/). 1. Currently we look back at the past 2 conversations (i.e. `CONVERSATION_HISTORY_MESSAGE_LOOKBACK`) 2. This is for text only prompt and responses. 3. Supports if user closes + re-opens app again. 3. As part of this needed to separate out how `prompt` is placed in the `generate()` function since system prompt is always first, followed by conversation history (if present) and then current prompt. Multi-turn format (with example from [Llama 3 Model Card](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3)): ``` <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|> What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|> What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Paris, the City of Light, offers a romantic getaway with must-see attractions like the Eiffel Tower and Louvre Museum, romantic experiences like river cruises and charming neighborhoods, and delicious food and drink options, with helpful tips for making the most of your trip.<|eot_id|><|start_header_id|>user<|end_header_id|> Give me a detailed list of the attractions I should visit, and time it takes in each one, to plan my trip accordingly.<|eot_id|><|start_header_id|>assistant<|end_header_id|> ``` Reviewed By: Riandy Differential Revision: D61134262
1 parent ec8dac9 commit 05cd087

File tree

5 files changed

+113
-9
lines changed

5 files changed

+113
-9
lines changed

examples/demo-apps/android/LlamaDemo/app/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ dependencies {
5757
implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12")
5858
implementation("com.facebook.fbjni:fbjni:0.5.1")
5959
implementation("com.google.code.gson:gson:2.8.6")
60-
implementation(files("libs/executorch-llama.aar"))
60+
implementation(files("libs/modified-executorch-llama.aar"))
6161
implementation("com.google.android.material:material:1.12.0")
6262
implementation("androidx.activity:activity:1.9.0")
6363
testImplementation("junit:junit:4.13.2")

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7070
private SettingsFields mCurrentSettingsFields;
7171
private Handler mMemoryUpdateHandler;
7272
private Runnable memoryUpdater;
73+
private int promptID = 0;
74+
75+
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
7376

7477
@Override
7578
public void onResult(String result) {
@@ -183,6 +186,11 @@ private void populateExistingMessages(String existingMsgJSON) {
183186
mMessageAdapter.notifyDataSetChanged();
184187
}
185188

189+
private int setPromptID() {
190+
191+
return mMessageAdapter.getMaxPromptID() + 1;
192+
}
193+
186194
@Override
187195
protected void onCreate(Bundle savedInstanceState) {
188196
super.onCreate(savedInstanceState);
@@ -204,6 +212,7 @@ protected void onCreate(Bundle savedInstanceState) {
204212
String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
205213
if (!existingMsgJSON.isEmpty()) {
206214
populateExistingMessages(existingMsgJSON);
215+
promptID = setPromptID();
207216
}
208217
mSettingsButton = requireViewById(R.id.settings);
209218
mSettingsButton.setOnClickListener(
@@ -540,6 +549,48 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
540549
mMessageAdapter.notifyDataSetChanged();
541550
}
542551

552+
private String getConversationHistory() {
553+
String conversationHistory = "";
554+
555+
ArrayList<Message> conversations =
556+
mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
557+
if (conversations.isEmpty()) {
558+
return conversationHistory;
559+
}
560+
561+
int prevPromptID = conversations.get(0).getPromptID();
562+
String conversationFormat =
563+
PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
564+
String format = conversationFormat;
565+
for (int i = 0; i < conversations.size(); i++) {
566+
Message conversation = conversations.get(i);
567+
int currentPromptID = conversation.getPromptID();
568+
if (currentPromptID != prevPromptID) {
569+
conversationHistory = conversationHistory + format;
570+
format = conversationFormat;
571+
prevPromptID = currentPromptID;
572+
}
573+
if (conversation.getIsSent()) {
574+
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
575+
} else {
576+
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
577+
}
578+
}
579+
conversationHistory = conversationHistory + format;
580+
581+
return conversationHistory;
582+
}
583+
584+
private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
585+
if (conversationHistory.isEmpty()) {
586+
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
587+
}
588+
589+
return mCurrentSettingsFields.getFormattedSystemPrompt()
590+
+ conversationHistory
591+
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
592+
}
593+
543594
private void onModelRunStarted() {
544595
mSendButton.setClickable(false);
545596
mSendButton.setImageResource(R.drawable.baseline_stop_24);
@@ -575,20 +626,21 @@ private void onModelRunStopped() {
575626
+ " bytes size = "
576627
+ image.getBytes().length);
577628
});
629+
String conversationHistory = getConversationHistory();
578630
String rawPrompt = mEditTextMessage.getText().toString();
579-
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
580631
// We store raw prompt into message adapter, because we don't want to show the extra
581632
// tokens from system prompt
582-
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
633+
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
583634
mMessageAdapter.notifyDataSetChanged();
584635
mEditTextMessage.setText("");
585-
mResultMessage = new Message("", false, MessageType.TEXT, 0);
636+
mResultMessage = new Message("", false, MessageType.TEXT, promptID);
586637
mMessageAdapter.add(mResultMessage);
587638
// Scroll to bottom of the list
588639
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
589640
// After images are added to prompt and chat thread, we clear the imageURI list
590641
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
591642
mSelectedImageUri = null;
643+
promptID++;
592644
Runnable runnable =
593645
new Runnable() {
594646
@Override
@@ -600,9 +652,10 @@ public void run() {
600652
onModelRunStarted();
601653
}
602654
});
603-
ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
655+
String finalPrompt = getTotalFormattedPrompt(conversationHistory, rawPrompt);
656+
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
604657
long generateStartTime = System.currentTimeMillis();
605-
mModule.generate(prompt, false, MainActivity.this);
658+
mModule.generate(finalPrompt, 700,false, MainActivity.this);
606659
long generateDuration = System.currentTimeMillis() - generateStartTime;
607660
mResultMessage.setTotalGenerationTime(generateDuration);
608661
runOnUiThread(

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import android.widget.ImageView;
1717
import android.widget.TextView;
1818
import java.util.ArrayList;
19+
import java.util.Collections;
1920

2021
public class MessageAdapter extends ArrayAdapter<Message> {
2122

@@ -90,4 +91,41 @@ public void clear() {
9091
public ArrayList<Message> getSavedMessages() {
9192
return savedMessages;
9293
}
94+
95+
public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
96+
ArrayList<Message> recentMessages = new ArrayList<Message>();
97+
int lastIndex = savedMessages.size() - 1;
98+
Message messageToAdd = savedMessages.get(lastIndex);
99+
int oldPromptID = messageToAdd.getPromptID();
100+
101+
for (int i = 0; i < savedMessages.size(); i++) {
102+
messageToAdd = savedMessages.get(lastIndex - i);
103+
if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
104+
if (messageToAdd.getPromptID() != oldPromptID) {
105+
numOfLatestPromptMessages--;
106+
oldPromptID = messageToAdd.getPromptID();
107+
}
108+
if (numOfLatestPromptMessages > 0) {
109+
if (messageToAdd.getMessageType() == MessageType.TEXT) {
110+
recentMessages.add(messageToAdd);
111+
}
112+
} else {
113+
break;
114+
}
115+
}
116+
}
117+
118+
// To place the order in [input1, output1, input2, output2...]
119+
Collections.reverse(recentMessages);
120+
return recentMessages;
121+
}
122+
123+
public int getMaxPromptID() {
124+
int maxPromptID = -1;
125+
for (Message msg : savedMessages) {
126+
127+
maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
128+
}
129+
return maxPromptID;
130+
}
93131
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public class PromptFormat {
1212

1313
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
1414
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
15+
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
1516

1617
public static String getSystemPromptTemplate(ModelType modelType) {
1718
switch (modelType) {
@@ -32,8 +33,20 @@ public static String getUserPromptTemplate(ModelType modelType) {
3233
case LLAMA_3_1:
3334
return "<|start_header_id|>user<|end_header_id|>\n"
3435
+ USER_PLACEHOLDER
35-
+ "<|eot_id|>\n"
36+
+ "<|eot_id|>"
3637
+ "<|start_header_id|>assistant<|end_header_id|>";
38+
39+
case LLAVA_1_5:
40+
default:
41+
return USER_PLACEHOLDER;
42+
}
43+
}
44+
45+
public static String getConversationFormat(ModelType modelType) {
46+
switch (modelType) {
47+
case LLAMA_3:
48+
case LLAMA_3_1:
49+
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
3750
case LLAVA_1_5:
3851
default:
3952
return USER_PLACEHOLDER;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) {
3838
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
3939
}
4040

41-
private String getFormattedSystemPrompt() {
41+
public String getFormattedSystemPrompt() {
4242
return PromptFormat.getSystemPromptTemplate(modelType)
4343
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
4444
}
4545

46-
private String getFormattedUserPrompt(String prompt) {
46+
public String getFormattedUserPrompt(String prompt) {
4747
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
4848
}
4949

0 commit comments

Comments
 (0)