# Android app use stats from runner #2801

Status: Closed

## MainActivity.java

```diff
@@ -30,19 +30,26 @@ public class MainActivity extends Activity implements Runnable, LlamaCallback {
   private LlamaModule mModule = null;
   private Message mResultMessage = null;
 
-  private int mNumTokens = 0;
-  private long mRunStartTime = 0;
   private String mModelFilePath = "";
   private String mTokenizerFilePath = "";
 
   @Override
   public void onResult(String result) {
     System.out.println("onResult: " + result);
     mResultMessage.appendText(result);
-    mNumTokens++;
     run();
   }
 
+  @Override
+  public void onStats(float tps) {
+    runOnUiThread(
+        () -> {
+          if (mResultMessage != null) {
+            mResultMessage.setTokensPerSecond(tps);
+            mMessageAdapter.notifyDataSetChanged();
+          }
+        });
+  }
+
   private static String[] listLocalFile(String path, String suffix) {
     File directory = new File(path);
     if (directory.exists() && directory.isDirectory()) {
@@ -79,14 +86,14 @@ private void setLocalModel(String modelPath, String tokenizerPath) {
         });
     }
 
-    long runDuration = System.currentTimeMillis() - runStartTime;
+    long loadDuration = System.currentTimeMillis() - runStartTime;
     String modelInfo =
         "Model path: "
             + modelPath
             + "\nTokenizer path: "
            + tokenizerPath
             + "\nModel loaded time: "
-            + runDuration
+            + loadDuration
             + " ms";
     Message modelLoadedMessage = new Message(modelInfo, false);
     runOnUiThread(
@@ -175,16 +182,10 @@ private void onModelRunStarted() {
         view -> {
           mModule.stop();
         });
-
-    mRunStartTime = System.currentTimeMillis();
   }
 
   private void onModelRunStopped() {
     setTitle(memoryInfo());
-    long runDuration = System.currentTimeMillis() - mRunStartTime;
-    if (mResultMessage != null) {
-      mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
-    }
     mSendButton.setText("Generate");
     mSendButton.setOnClickListener(
         view -> {
@@ -219,8 +220,6 @@ public void run() {
               };
               new Thread(runnable).start();
             });
-    mNumTokens = 0;
-    mRunStartTime = 0;
     mMessageAdapter.notifyDataSetChanged();
   }
```
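
For contrast, a sketch of the app-side measurement these hunks delete: wall-clock time from run start divided into the token count. The class below is a hypothetical, self-contained stand-in (the real fields were mRunStartTime and mNumTokens, set in onModelRunStarted() and onResult()); it only illustrates the removed formula.

```java
// Runnable illustration of the measurement this PR removes (not part of the diff).
// In the app, the clock started in onModelRunStarted(), i.e. before prompt prefill,
// so the resulting tokens/second understates pure generation speed.
public class OldTpsMeasurement {
  public static void main(String[] args) throws InterruptedException {
    long runStartTime = System.currentTimeMillis(); // played by mRunStartTime
    int numTokens = 0;                              // played by mNumTokens

    // Stand-in for generation: 50 simulated onResult() callbacks, ~40 ms apart.
    for (int i = 0; i < 50; i++) {
      Thread.sleep(40);
      numTokens++; // the app did mNumTokens++ in onResult()
    }

    // The removed formula from onModelRunStopped():
    long runDuration = System.currentTimeMillis() - runStartTime;
    float tps = 1.0f * numTokens / (runDuration / 1000.0f);
    System.out.println(tps + " tokens/s");
  }
}
```
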
## extension/android/jni/jni_layer_llama.cpp (16 additions & 3 deletions)

```diff
@@ -72,6 +72,18 @@ class ExecuTorchLlamaCallbackJni
     facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
     method(self(), s);
   }
+
+  void onStats(const Runner::Stats& result) const {
+    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static const auto method = cls->getMethod<void(jfloat)>("onStats");
+    double eval_time =
+        (double)(result.inference_end_ms - result.prompt_eval_end_ms);
+
+    float tps = result.num_generated_tokens / eval_time *
+        result.SCALING_FACTOR_UNITS_PER_SECOND;
+
+    method(self(), tps);
+  }
 };
 
 class ExecuTorchLlamaJni
@@ -117,9 +129,10 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     runner_->generate(
-        prompt->toStdString(), 128, [callback](std::string result) {
-          callback->onResult(result);
-        });
+        prompt->toStdString(),
+        128,
+        [callback](std::string result) { callback->onResult(result); },
+        [callback](const Runner::Stats& result) { callback->onStats(result); });
     return 0;
   }
```

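The tokens/second figure sent to Java is the generated-token count divided by the interval from end of prompt evaluation to end of inference. A worked example, assuming the stats timestamps are in milliseconds and SCALING_FACTOR_UNITS_PER_SECOND is 1000 (both assumptions; the values below are hypothetical):

```java
public class TpsExample {
  public static void main(String[] args) {
    // Hypothetical stats values; timestamps assumed to be in milliseconds.
    long promptEvalEndMs = 1_500;  // prefill finished
    long inferenceEndMs = 3_500;   // generation finished
    long numGeneratedTokens = 50;
    long scalingFactor = 1_000;    // assumed SCALING_FACTOR_UNITS_PER_SECOND

    double evalTime = (double) (inferenceEndMs - promptEvalEndMs);        // 2000 ms
    float tps = (float) (numGeneratedTokens / evalTime * scalingFactor); // 25.0
    System.out.println(tps + " tokens/s");
  }
}
```

Note that the interval starts at prompt_eval_end_ms, so prefill time is excluded; this is why the runner's figure can be higher than the app-side measurement it replaces.
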
## LlamaCallback.java

```diff
@@ -11,7 +11,20 @@
 import com.facebook.jni.annotations.DoNotStrip;
 
 public interface LlamaCallback {
-  /** Called when a new result is available from JNI. User should override this method. */
+  /**
+   * Called when a new result is available from JNI. Users will keep getting onResult() invocations
+   * until generate() finishes.
+   *
+   * @param result Last generated token
+   */
   @DoNotStrip
   public void onResult(String result);
+
+  /**
+   * Called when statistics for generate() are available.
+   *
+   * @param tps Tokens per second for generated tokens.
+   */
+  @DoNotStrip
+  public void onStats(float tps);
 }
```
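
A minimal sketch of implementing the updated interface; the class name is invented, and the import of LlamaCallback is left to wherever the interface is declared in this repo:

```java
// Hypothetical implementor of the updated LlamaCallback interface.
// Add an import for LlamaCallback matching its actual package.
public class LoggingCallback implements LlamaCallback {
  private final StringBuilder mText = new StringBuilder();

  @Override
  public void onResult(String result) {
    // Invoked once per generated token until generate() finishes.
    mText.append(result);
  }

  @Override
  public void onStats(float tps) {
    // Invoked when generation statistics become available.
    System.out.println(mText + "\n[" + tps + " tokens/s]");
  }
}
```
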
## LlamaModule.java

```diff
@@ -27,6 +27,7 @@ public class LlamaModule {
   private static native HybridData initHybrid(
       String modulePath, String tokenizerPath, float temperature);
 
+  /** Constructs a LLAMA Module for a model with the given path, tokenizer, and temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
     mHybridData = initHybrid(modulePath, tokenizerPath, temperature);
   }
@@ -35,12 +36,20 @@ public void resetNative() {
     mHybridData.resetNative();
   }
 
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param llamaCallback Callback object to receive results
+   */
   @DoNotStrip
   public native int generate(String prompt, LlamaCallback llamaCallback);
 
+  /** Stop the current generate() before it finishes. */
   @DoNotStrip
   public native void stop();
 
+  /** Force-load the module; otherwise the model is loaded during the first generate(). */
   @DoNotStrip
   public native int load();
 }
```
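
Taken together, a hedged sketch of typical use of the documented API; the file paths are placeholders, and LoggingCallback is the hypothetical implementor sketched earlier:

```java
public class LlamaModuleUsage {
  public static void main(String[] args) {
    // Placeholder paths; point these at the model and tokenizer on device.
    LlamaModule module = new LlamaModule(
        "/data/local/tmp/llama/model.pte",
        "/data/local/tmp/llama/tokenizer.bin",
        0.8f); // temperature

    module.load(); // optional: force-load now rather than during the first generate()

    int status = module.generate("Hello, ", new LoggingCallback());

    // stop() can be called from elsewhere (the demo wires it to a button)
    // to end generation early:
    // module.stop();
  }
}
```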