Skip to content

Commit 5e91666

Browse files
committed
Fix tps
1 parent c7735bc commit 5e91666

File tree

4 files changed

+8
-17
lines changed

4 files changed

+8
-17
lines changed

examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void onResult(String result) {
6969
public void onStats(String result) {
7070
try {
7171
JSONObject jsonObject = new JSONObject(result);
72-
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
72+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
7373
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
7474
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
7575
float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public void onStats(String stats) {
106106
float tps = 0;
107107
try {
108108
JSONObject jsonObject = new JSONObject(stats);
109-
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
109+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
110110
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
111111
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
112112
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public void onStats(String stats) {
7575
float tps = 0;
7676
try {
7777
JSONObject jsonObject = new JSONObject(stats);
78-
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
78+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
7979
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
8080
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
8181
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ public void onStats(String stats) {
8585
float tps = 0;
8686
try {
8787
JSONObject jsonObject = new JSONObject(stats);
88-
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
88+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
8989
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
9090
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
9191
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
92-
mStatsInfo.tokens = String.valueOf(tps);
92+
mStatsInfo.tps = tps;
9393
} catch (JSONException e) {
9494
Log.e("LLM", "Error parsing JSON: " + e.getMessage());
9595
}
@@ -121,7 +121,7 @@ public void onGenerationStopped() {
121121
0.0f));
122122
// Token per second
123123
results.add(
124-
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f));
124+
new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
125125

126126
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
127127
Gson gson = new Gson();
@@ -130,15 +130,6 @@ public void onGenerationStopped() {
130130
e.printStackTrace();
131131
}
132132
}
133-
134-
private double extractTPS(final String tokens) {
135-
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
136-
if (m.find()) {
137-
return Double.parseDouble(m.group());
138-
} else {
139-
return 0.0f;
140-
}
141-
}
142133
}
143134

144135
class StatsInfo {
@@ -147,7 +138,7 @@ class StatsInfo {
147138
long loadEnd;
148139
long generateStart;
149140
long generateEnd;
150-
String tokens;
141+
float tps;
151142
String modelName;
152143

153144
@Override
@@ -161,6 +152,6 @@ public String toString() {
161152
+ "\ngenerateEnd: "
162153
+ generateEnd
163154
+ "\n"
164-
+ tokens;
155+
+ tps;
165156
}
166157
}

0 commit comments

Comments
 (0)