Skip to content

[Android] Use new Llm package API #9495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/demo-apps/android/LlamaDemo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Optional Parameters:

```java
// Upon returning to the Main Chat Activity
mModule = new LlamaModule(
mModule = new LlmModule(
ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()),
modelPath,
tokenizerPath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

@RunWith(AndroidJUnit4.class)
public class PerfTest implements LlamaCallback {
public class PerfTest implements LlmCallback {

private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
private static final String TOKENIZER_BIN = "tokenizer.bin";
Expand All @@ -41,7 +41,7 @@ public void testTokensPerSecond() {
.filter(file -> file.getName().endsWith(".pte"))
.forEach(
model -> {
LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f);
LlmModule mModule = new LlmModule(model.getPath(), tokenizerPath, 0.8f);
// Print the model name because there might be more than one of them
report("ModelName", model.getName());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback {
public class MainActivity extends AppCompatActivity implements Runnable, LlmCallback {
private EditText mEditTextMessage;
private ImageButton mSendButton;
private ImageButton mGalleryButton;
private ImageButton mCameraButton;
private ListView mMessagesView;
private MessageAdapter mMessageAdapter;
private LlamaModule mModule = null;
private LlmModule mModule = null;
private Message mResultMessage = null;
private ImageButton mSettingsButton;
private TextView mMemoryView;
Expand Down Expand Up @@ -124,7 +124,7 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
}
long runStartTime = System.currentTimeMillis();
mModule =
new LlamaModule(
new LlmModule(
ModelUtils.getModelCategory(
mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()),
modelPath,
Expand Down Expand Up @@ -714,7 +714,7 @@ private void onModelRunStopped() {
// Scroll to bottom of the list
mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
// After images are added to prompt and chat thread, we clear the imageURI list
// Note: This has to be done after imageURIs are no longer needed by LlamaModule
// Note: This has to be done after imageURIs are no longer needed by LlmModule
mSelectedImageUri = null;
promptID++;
Runnable runnable =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
import android.os.Looper;
import android.os.Message;
import androidx.annotation.NonNull;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

/** Helper that encapsulates all model-running logic in a single class. */
public class ModelRunner implements LlamaCallback {
LlamaModule mModule = null;
public class ModelRunner implements LlmCallback {
LlmModule mModule = null;

String mModelFilePath = "";
String mTokenizerFilePath = "";
Expand All @@ -45,7 +45,7 @@ public class ModelRunner implements LlamaCallback {
mTokenizerFilePath = tokenizerFilePath;
mCallback = callback;

mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f);
mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f);
mHandlerThread = new HandlerThread("ModelRunner");
mHandlerThread.start();
mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public void onTextChanged(CharSequence s, int start, int before, int count) {}
public void afterTextChanged(Editable s) {
mSetTemperature = Double.parseDouble(s.toString());
// This is needed because temperature is changed together with model loading
// Once temperature is no longer in LlamaModule constructor, we can remove this
// Once temperature is no longer in LlmModule constructor, we can remove this
mSettingsFields.saveLoadModelAction(true);
saveSettings();
}
Expand Down
2 changes: 0 additions & 2 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ fb_android_library(
fb_android_library(
name = "executorch_llama",
srcs = [
"src/main/java/org/pytorch/executorch/LlamaCallback.java",
"src/main/java/org/pytorch/executorch/LlamaModule.java",
"src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
"src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
],
Expand Down

This file was deleted.

This file was deleted.

Loading
Loading