Skip to content

[Android] Use same stats as llm::Stats #10247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dependencies {
implementation(files("libs/executorch.aar"))
implementation("com.google.android.material:material:1.12.0")
implementation("androidx.activity:activity:1.9.0")
implementation("org.json:json:20250107")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.json.JSONException;
import org.json.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.pytorch.executorch.extension.llm.LlmCallback;
Expand Down Expand Up @@ -64,8 +66,16 @@ public void onResult(String result) {
}

@Override
public void onStats(float tps) {
tokensPerSecond.add(tps);
/**
 * Parses the JSON stats payload emitted by the LLM runtime and records the
 * decode speed (tokens/second) for this generation pass.
 *
 * @param result JSON string with the llm::Stats fields; see extension/llm/stats.h.
 */
public void onStats(String result) {
  try {
    JSONObject jsonObject = new JSONObject(result);
    int numGeneratedTokens = jsonObject.getInt("generated_tokens");
    int inferenceEndMs = jsonObject.getInt("inference_end_ms");
    int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
    // Decode phase spans from the end of prompt eval to the end of inference.
    int evalTimeMs = inferenceEndMs - promptEvalEndMs;
    // Guard the division: a zero/negative window would record Infinity,
    // since float division by zero does not throw in Java.
    if (evalTimeMs > 0) {
      float tps = (float) numGeneratedTokens * 1000 / evalTimeMs;
      tokensPerSecond.add(tps);
    }
  } catch (JSONException e) {
    // Malformed stats payload: skip this sample rather than fail the benchmark run.
  }
}

private void report(final String metric, final Float value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.json.JSONException;
import org.json.JSONObject;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

Expand Down Expand Up @@ -97,10 +99,20 @@ public void onResult(String result) {
}

@Override
public void onStats(float tps) {
public void onStats(String stats) {
runOnUiThread(
() -> {
if (mResultMessage != null) {
float tps = 0;
try {
JSONObject jsonObject = new JSONObject(stats);
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
} catch (JSONException e) {
Log.e("LLM", "Error parsing JSON: " + e.getMessage());
}
mResultMessage.setTokensPerSecond(tps);
mMessageAdapter.notifyDataSetChanged();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import android.os.Looper;
import android.os.Message;
import androidx.annotation.NonNull;
import org.json.JSONException;
import org.json.JSONObject;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

Expand Down Expand Up @@ -69,7 +71,16 @@ public void onResult(String result) {
}

@Override
public void onStats(float tps) {
/**
 * Computes tokens/second from the runtime's JSON stats payload and forwards a
 * human-readable summary to the client callback.
 *
 * @param stats JSON string with the llm::Stats fields; see extension/llm/stats.h.
 */
public void onStats(String stats) {
  float tps = 0;
  try {
    JSONObject jsonObject = new JSONObject(stats);
    int numGeneratedTokens = jsonObject.getInt("generated_tokens");
    int inferenceEndMs = jsonObject.getInt("inference_end_ms");
    int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
    int evalTimeMs = inferenceEndMs - promptEvalEndMs;
    // Guard the division: a zero/negative decode window would report Infinity
    // (float division by zero does not throw in Java).
    if (evalTimeMs > 0) {
      tps = (float) numGeneratedTokens * 1000 / evalTimeMs;
    }
  } catch (JSONException e) {
    // Keep tps == 0 so the callback still fires even on a malformed payload.
  }
  mCallback.onStats("tokens/second: " + tps);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {

void onTokenGenerated(String token);

void onStats(String token);
void onStats(String stats);

void onGenerationStopped();
}
1 change: 1 addition & 0 deletions extension/android/executorch_android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies {
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test:rules:1.2.0'
androidTestImplementation 'commons-io:commons-io:2.4'
androidTestImplementation 'org.json:json:20250107'
}

import com.vanniktech.maven.publish.SonatypeHost
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.apache.commons.io.FileUtils;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.InstrumentationRegistry;
import org.json.JSONException;
import org.json.JSONObject;
import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

Expand Down Expand Up @@ -94,8 +96,17 @@ public void onResult(String result) {
}

@Override
public void onStats(float tps) {
LlmModuleInstrumentationTest.this.onStats(tps);
/**
 * Parses the JSON stats payload, derives tokens/second for the decode phase,
 * and reports it to the enclosing test for assertion.
 *
 * @param stats JSON string with the llm::Stats fields; see extension/llm/stats.h.
 */
public void onStats(String stats) {
  float tps = 0;
  try {
    JSONObject jsonObject = new JSONObject(stats);
    int numGeneratedTokens = jsonObject.getInt("generated_tokens");
    int inferenceEndMs = jsonObject.getInt("inference_end_ms");
    int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
    int evalTimeMs = inferenceEndMs - promptEvalEndMs;
    // Avoid recording Infinity when the decode window is zero or negative
    // (Java float division by zero yields Infinity, not an exception).
    if (evalTimeMs > 0) {
      tps = (float) numGeneratedTokens * 1000 / evalTimeMs;
      LlmModuleInstrumentationTest.this.onStats(tps);
    }
  } catch (JSONException e) {
    // Malformed stats payload: record nothing; the test will fail on the
    // missing sample rather than on a spurious value.
  }
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,22 @@ public interface LlmCallback {
/**
* Called when the statistics for generate() are available.
*
* Note: This is a deprecated API and will be removed in the future. Please use onStats(String stats)
*
* @param tps Tokens/second for generated tokens.
*/
@Deprecated
@DoNotStrip
default public void onStats(float tps) {}

/**
* Called when the statistics for generate() are available.
*
* The result will be a JSON string. See extension/llm/stats.h for the field
* definitions.
*
* @param stats JSON string containing the statistics for the generate()
*/
@DoNotStrip
public void onStats(float tps);
default public void onStats(String stats) {}
}
12 changes: 9 additions & 3 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,20 @@ class ExecuTorchLlmCallbackJni

// Invoked by the native runner when generation statistics are available.
// Dispatches to BOTH Java overloads of onStats: the deprecated float
// (tokens/second) overload for backward compatibility, and the new String
// overload carrying the full stats serialized as JSON.
void onStats(const llm::Stats& result) const {
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
// Legacy overload: onStats(float tps). fbjni resolves the overload by the
// signature given to getMethod.
static const auto tps_method = cls->getMethod<void(jfloat)>("onStats");
// Decode window: end of prompt eval to end of inference, in milliseconds.
double eval_time =
(double)(result.inference_end_ms - result.prompt_eval_end_ms);

// NOTE(review): if eval_time is 0 this produces +/-inf — presumably the
// runtime guarantees a non-empty decode window; confirm upstream.
float tps = result.num_generated_tokens / eval_time *
result.SCALING_FACTOR_UNITS_PER_SECOND;

tps_method(self(), tps);

// New overload: onStats(String statsJson). Field names follow
// extension/llm/stats.h (stats_to_json_string).
static const auto on_stats_method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onStats");
on_stats_method(
self(),
facebook::jni::make_jstring(
executorch::extension::llm::stats_to_json_string(result)));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies {
implementation("com.facebook.soloader:soloader:0.10.5")
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
implementation("org.json:json:20250107")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.json.JSONException;
import org.json.JSONObject;

public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
Expand Down Expand Up @@ -80,7 +82,17 @@ public void onTokenGenerated(String token) {}

@Override
/**
 * Extracts tokens/second from the runtime's JSON stats payload and stores it
 * on the benchmark record for later reporting.
 *
 * @param stats JSON string with the llm::Stats fields; see extension/llm/stats.h.
 */
public void onStats(String stats) {
  float tps = 0;
  try {
    JSONObject jsonObject = new JSONObject(stats);
    int numGeneratedTokens = jsonObject.getInt("generated_tokens");
    int inferenceEndMs = jsonObject.getInt("inference_end_ms");
    int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
    int evalTimeMs = inferenceEndMs - promptEvalEndMs;
    // Guard the division: a zero/negative decode window would store Infinity
    // (float division by zero does not throw in Java).
    if (evalTimeMs > 0) {
      tps = (float) numGeneratedTokens * 1000 / evalTimeMs;
    }
    mStatsInfo.tps = tps;
  } catch (JSONException e) {
    // Pass the Throwable so the full stack trace is logged, not just the message.
    Log.e("LLM", "Error parsing stats JSON", e);
  }
}

@Override
Expand Down Expand Up @@ -109,7 +121,7 @@ public void onGenerationStopped() {
0.0f));
// Token per second
results.add(
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f));
new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
Expand All @@ -118,15 +130,6 @@ public void onGenerationStopped() {
e.printStackTrace();
}
}

private double extractTPS(final String tokens) {
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
if (m.find()) {
return Double.parseDouble(m.group());
} else {
return 0.0f;
}
}
}

class StatsInfo {
Expand All @@ -135,7 +138,7 @@ class StatsInfo {
long loadEnd;
long generateStart;
long generateEnd;
String tokens;
float tps;
String modelName;

@Override
Expand All @@ -149,6 +152,6 @@ public String toString() {
+ "\ngenerateEnd: "
+ generateEnd
+ "\n"
+ tokens;
+ tps;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public void onResult(String result) {
}

@Override
public void onStats(float tps) {
mCallback.onStats("tokens/second: " + tps);
public void onStats(String result) {
// Forward the raw JSON stats payload from the runtime unchanged; the
// receiving callback is responsible for parsing it.
mCallback.onStats(result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {

void onTokenGenerated(String token);

void onStats(String token);
void onStats(String result);

void onGenerationStopped();
}
Loading