Skip to content

Commit d82e852

Browse files
authored
[Android] Use same stats as llm::Stats
Differential Revision: D73207250 Pull Request resolved: #10247
1 parent cbca483 commit d82e852

File tree

13 files changed

+97
-27
lines changed

13 files changed

+97
-27
lines changed

examples/demo-apps/android/LlamaDemo/app/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ dependencies {
6060
implementation(files("libs/executorch.aar"))
6161
implementation("com.google.android.material:material:1.12.0")
6262
implementation("androidx.activity:activity:1.9.0")
63+
implementation("org.json:json:20250107")
6364
testImplementation("junit:junit:4.13.2")
6465
androidTestImplementation("androidx.test.ext:junit:1.1.5")
6566
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import java.util.ArrayList;
1919
import java.util.Arrays;
2020
import java.util.List;
21+
import org.json.JSONException;
22+
import org.json.JSONObject;
2123
import org.junit.Test;
2224
import org.junit.runner.RunWith;
2325
import org.pytorch.executorch.extension.llm.LlmCallback;
@@ -64,8 +66,16 @@ public void onResult(String result) {
6466
}
6567

6668
@Override
67-
public void onStats(float tps) {
68-
tokensPerSecond.add(tps);
69+
public void onStats(String result) {
70+
try {
71+
JSONObject jsonObject = new JSONObject(result);
72+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
73+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
74+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
75+
float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
76+
tokensPerSecond.add(tps);
77+
} catch (JSONException e) {
78+
}
6979
}
7080

7181
private void report(final String metric, final Float value) {

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import java.util.List;
5050
import java.util.concurrent.Executor;
5151
import java.util.concurrent.Executors;
52+
import org.json.JSONException;
53+
import org.json.JSONObject;
5254
import org.pytorch.executorch.extension.llm.LlmCallback;
5355
import org.pytorch.executorch.extension.llm.LlmModule;
5456

@@ -97,10 +99,20 @@ public void onResult(String result) {
9799
}
98100

99101
@Override
100-
public void onStats(float tps) {
102+
public void onStats(String stats) {
101103
runOnUiThread(
102104
() -> {
103105
if (mResultMessage != null) {
106+
float tps = 0;
107+
try {
108+
JSONObject jsonObject = new JSONObject(stats);
109+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
110+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
111+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
112+
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
113+
} catch (JSONException e) {
114+
Log.e("LLM", "Error parsing JSON: " + e.getMessage());
115+
}
104116
mResultMessage.setTokensPerSecond(tps);
105117
mMessageAdapter.notifyDataSetChanged();
106118
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import android.os.Looper;
1414
import android.os.Message;
1515
import androidx.annotation.NonNull;
16+
import org.json.JSONException;
17+
import org.json.JSONObject;
1618
import org.pytorch.executorch.extension.llm.LlmCallback;
1719
import org.pytorch.executorch.extension.llm.LlmModule;
1820

@@ -69,7 +71,16 @@ public void onResult(String result) {
6971
}
7072

7173
@Override
72-
public void onStats(float tps) {
74+
public void onStats(String stats) {
75+
float tps = 0;
76+
try {
77+
JSONObject jsonObject = new JSONObject(stats);
78+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
79+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
80+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
81+
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
82+
} catch (JSONException e) {
83+
}
7384
mCallback.onStats("tokens/second: " + tps);
7485
}
7586
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
1818

1919
void onTokenGenerated(String token);
2020

21-
void onStats(String token);
21+
void onStats(String stats);
2222

2323
void onGenerationStopped();
2424
}

extension/android/executorch_android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies {
4747
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
4848
androidTestImplementation 'androidx.test:rules:1.2.0'
4949
androidTestImplementation 'commons-io:commons-io:2.4'
50+
androidTestImplementation 'org.json:json:20250107'
5051
}
5152

5253
import com.vanniktech.maven.publish.SonatypeHost

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.apache.commons.io.FileUtils;
3535
import androidx.test.ext.junit.runners.AndroidJUnit4;
3636
import androidx.test.InstrumentationRegistry;
37+
import org.json.JSONException;
38+
import org.json.JSONObject;
3739
import org.pytorch.executorch.extension.llm.LlmCallback;
3840
import org.pytorch.executorch.extension.llm.LlmModule;
3941

@@ -94,8 +96,17 @@ public void onResult(String result) {
9496
}
9597

9698
@Override
97-
public void onStats(float tps) {
98-
LlmModuleInstrumentationTest.this.onStats(tps);
99+
public void onStats(String stats) {
100+
float tps = 0;
101+
try {
102+
JSONObject jsonObject = new JSONObject(stats);
103+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
104+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
105+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
106+
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
107+
LlmModuleInstrumentationTest.this.onStats(tps);
108+
} catch (JSONException e) {
109+
}
99110
}
100111
});
101112

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,22 @@ public interface LlmCallback {
3131
/**
3232
* Called when the statistics for generate() are available.
3333
*
34+
* Note: This API is deprecated and will be removed in a future release. Please use onStats(String stats) instead.
35+
*
3436
* @param tps Tokens/second for generated tokens.
3537
*/
38+
@Deprecated
39+
@DoNotStrip
40+
default public void onStats(float tps) {}
41+
42+
/**
43+
* Called when the statistics for generate() are available.
44+
*
45+
* The result will be a JSON string. See extension/llm/stats.h for the field
46+
* definitions.
47+
*
48+
* @param stats JSON string containing the statistics for the generate() call
49+
*/
3650
@DoNotStrip
37-
public void onStats(float tps);
51+
default public void onStats(String stats) {}
3852
}

extension/android/jni/jni_layer_llama.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,20 @@ class ExecuTorchLlmCallbackJni
100100

101101
void onStats(const llm::Stats& result) const {
102102
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
103-
static const auto method = cls->getMethod<void(jfloat)>("onStats");
103+
static const auto tps_method = cls->getMethod<void(jfloat)>("onStats");
104104
double eval_time =
105105
(double)(result.inference_end_ms - result.prompt_eval_end_ms);
106106

107107
float tps = result.num_generated_tokens / eval_time *
108108
result.SCALING_FACTOR_UNITS_PER_SECOND;
109-
110-
method(self(), tps);
109+
tps_method(self(), tps);
110+
111+
static const auto on_stats_method =
112+
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onStats");
113+
on_stats_method(
114+
self(),
115+
facebook::jni::make_jstring(
116+
executorch::extension::llm::stats_to_json_string(result)));
111117
}
112118
};
113119

extension/benchmark/android/benchmark/app/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies {
3939
implementation("com.facebook.soloader:soloader:0.10.5")
4040
implementation("com.facebook.fbjni:fbjni:0.5.1")
4141
implementation("com.google.code.gson:gson:2.8.6")
42+
implementation("org.json:json:20250107")
4243
testImplementation("junit:junit:4.13.2")
4344
androidTestImplementation("androidx.test.ext:junit:1.2.1")
4445
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.List;
2424
import java.util.regex.Matcher;
2525
import java.util.regex.Pattern;
26+
import org.json.JSONException;
27+
import org.json.JSONObject;
2628

2729
public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
2830
ModelRunner mModelRunner;
@@ -80,7 +82,17 @@ public void onTokenGenerated(String token) {}
8082

8183
@Override
8284
public void onStats(String stats) {
83-
mStatsInfo.tokens = stats;
85+
float tps = 0;
86+
try {
87+
JSONObject jsonObject = new JSONObject(stats);
88+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
89+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
90+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
91+
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
92+
mStatsInfo.tps = tps;
93+
} catch (JSONException e) {
94+
Log.e("LLM", "Error parsing JSON: " + e.getMessage());
95+
}
8496
}
8597

8698
@Override
@@ -109,7 +121,7 @@ public void onGenerationStopped() {
109121
0.0f));
110122
// Tokens per second
111123
results.add(
112-
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f));
124+
new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
113125

114126
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
115127
Gson gson = new Gson();
@@ -118,15 +130,6 @@ public void onGenerationStopped() {
118130
e.printStackTrace();
119131
}
120132
}
121-
122-
private double extractTPS(final String tokens) {
123-
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
124-
if (m.find()) {
125-
return Double.parseDouble(m.group());
126-
} else {
127-
return 0.0f;
128-
}
129-
}
130133
}
131134

132135
class StatsInfo {
@@ -135,7 +138,7 @@ class StatsInfo {
135138
long loadEnd;
136139
long generateStart;
137140
long generateEnd;
138-
String tokens;
141+
float tps;
139142
String modelName;
140143

141144
@Override
@@ -149,6 +152,6 @@ public String toString() {
149152
+ "\ngenerateEnd: "
150153
+ generateEnd
151154
+ "\n"
152-
+ tokens;
155+
+ tps;
153156
}
154157
}

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ public void onResult(String result) {
6868
}
6969

7070
@Override
71-
public void onStats(float tps) {
72-
mCallback.onStats("tokens/second: " + tps);
71+
public void onStats(String result) {
72+
mCallback.onStats(result);
7373
}
7474
}
7575

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
1818

1919
void onTokenGenerated(String token);
2020

21-
void onStats(String token);
21+
void onStats(String result);
2222

2323
void onGenerationStopped();
2424
}

0 commit comments

Comments
 (0)