Skip to content

Enable Llama3 Multi-turn conversation #4721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
private SettingsFields mCurrentSettingsFields;
private Handler mMemoryUpdateHandler;
private Runnable memoryUpdater;
private int promptID = 0;

private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;

@Override
public void onResult(String result) {
Expand Down Expand Up @@ -195,6 +198,11 @@ private void populateExistingMessages(String existingMsgJSON) {
mMessageAdapter.notifyDataSetChanged();
}

private int setPromptID() {

return mMessageAdapter.getMaxPromptID() + 1;
}

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Expand All @@ -216,6 +224,7 @@ protected void onCreate(Bundle savedInstanceState) {
String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
if (!existingMsgJSON.isEmpty()) {
populateExistingMessages(existingMsgJSON);
promptID = setPromptID();
}
mSettingsButton = requireViewById(R.id.settings);
mSettingsButton.setOnClickListener(
Expand Down Expand Up @@ -552,6 +561,48 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
mMessageAdapter.notifyDataSetChanged();
}

private String getConversationHistory() {
String conversationHistory = "";

ArrayList<Message> conversations =
mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
if (conversations.isEmpty()) {
return conversationHistory;
}

int prevPromptID = conversations.get(0).getPromptID();
String conversationFormat =
PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
String format = conversationFormat;
for (int i = 0; i < conversations.size(); i++) {
Message conversation = conversations.get(i);
int currentPromptID = conversation.getPromptID();
if (currentPromptID != prevPromptID) {
conversationHistory = conversationHistory + format;
format = conversationFormat;
prevPromptID = currentPromptID;
}
if (conversation.getIsSent()) {
format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
} else {
format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
}
}
conversationHistory = conversationHistory + format;

return conversationHistory;
}

private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
if (conversationHistory.isEmpty()) {
return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
}

return mCurrentSettingsFields.getFormattedSystemPrompt()
+ conversationHistory
+ mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
}

private void onModelRunStarted() {
mSendButton.setClickable(false);
mSendButton.setImageResource(R.drawable.baseline_stop_24);
Expand Down Expand Up @@ -586,19 +637,19 @@ private void onModelRunStopped() {
+ image.getBytes().length);
});
String rawPrompt = mEditTextMessage.getText().toString();
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
// We store raw prompt into message adapter, because we don't want to show the extra
// tokens from system prompt
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
mMessageAdapter.notifyDataSetChanged();
mEditTextMessage.setText("");
mResultMessage = new Message("", false, MessageType.TEXT, 0);
mResultMessage = new Message("", false, MessageType.TEXT, promptID);
mMessageAdapter.add(mResultMessage);
// Scroll to bottom of the list
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
// After images are added to prompt and chat thread, we clear the imageURI list
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
mSelectedImageUri = null;
promptID++;
Runnable runnable =
new Runnable() {
@Override
Expand All @@ -610,10 +661,10 @@ public void run() {
onModelRunStarted();
}
});
ETLogging.getInstance().log("Running inference.. prompt=" + prompt);
long generateStartTime = System.currentTimeMillis();
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
== ModelUtils.VISION_MODEL) {
ETLogging.getInstance().log("Running inference.. prompt=" + rawPrompt);
if (!processedImageList.isEmpty()) {
// For now, Llava only support 1 image.
ETImage img = processedImageList.get(0);
Expand All @@ -622,7 +673,7 @@ public void run() {
img.getWidth(),
img.getHeight(),
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
prompt,
rawPrompt,
ModelUtils.VISION_MODEL_SEQ_LEN,
false,
MainActivity.this);
Expand All @@ -633,14 +684,20 @@ public void run() {
0,
0,
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
prompt,
rawPrompt,
ModelUtils.VISION_MODEL_SEQ_LEN,
false,
MainActivity.this);
}
} else {
String finalPrompt =
getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
mModule.generate(
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
finalPrompt,
(int) (finalPrompt.length() * 0.75) + 64,
false,
MainActivity.this);
}

long generateDuration = System.currentTimeMillis() - generateStartTime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import android.widget.ImageView;
import android.widget.TextView;
import java.util.ArrayList;
import java.util.Collections;

public class MessageAdapter extends ArrayAdapter<Message> {

Expand Down Expand Up @@ -90,4 +91,41 @@ public void clear() {
public ArrayList<Message> getSavedMessages() {
return savedMessages;
}

public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
ArrayList<Message> recentMessages = new ArrayList<Message>();
int lastIndex = savedMessages.size() - 1;
Message messageToAdd = savedMessages.get(lastIndex);
int oldPromptID = messageToAdd.getPromptID();

for (int i = 0; i < savedMessages.size(); i++) {
messageToAdd = savedMessages.get(lastIndex - i);
if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
if (messageToAdd.getPromptID() != oldPromptID) {
numOfLatestPromptMessages--;
oldPromptID = messageToAdd.getPromptID();
}
if (numOfLatestPromptMessages > 0) {
if (messageToAdd.getMessageType() == MessageType.TEXT) {
recentMessages.add(messageToAdd);
}
} else {
break;
}
}
}

// To place the order in [input1, output1, input2, output2...]
Collections.reverse(recentMessages);
return recentMessages;
}

public int getMaxPromptID() {
int maxPromptID = -1;
for (Message msg : savedMessages) {

maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
}
return maxPromptID;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class PromptFormat {

public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";

public static String getSystemPromptTemplate(ModelType modelType) {
switch (modelType) {
Expand All @@ -33,8 +34,20 @@ public static String getUserPromptTemplate(ModelType modelType) {
case LLAMA_3_1:
return "<|start_header_id|>user<|end_header_id|>\n"
+ USER_PLACEHOLDER
+ "<|eot_id|>\n"
+ "<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>";

case LLAVA_1_5:
default:
return USER_PLACEHOLDER;
}
}

public static String getConversationFormat(ModelType modelType) {
switch (modelType) {
case LLAMA_3:
case LLAMA_3_1:
return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
case LLAVA_1_5:
return USER_PLACEHOLDER + " ASSISTANT:";
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) {
return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
}

private String getFormattedSystemPrompt() {
public String getFormattedSystemPrompt() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

package private?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is an issue? The reasoning for this is that in MainActivity we can add the system prompt first before any conversational history.

return PromptFormat.getSystemPromptTemplate(modelType)
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
}

private String getFormattedUserPrompt(String prompt) {
public String getFormattedUserPrompt(String prompt) {
return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
}

Expand Down
Loading