Skip to content

Commit 0d1644f

Browse files
huydhn authored and facebook-github-bot committed
Define generic Android benchmark metric structure (#5332)
Summary: To be able to display the benchmark results, we need the following information: 1. About the model * Name, i.e. `mv2` * The backend it uses, i.e. `xnnpack` * The quantization (dtype) applied, i.e. `q8` 2. About the metric * Name, i.e. `token_per_sec`. Note that this needs to be flexible to cover future metrics * Value * An optional target (so that we can highlight regression if it happens) 3. More metadata * The device name, i.e. `samsung` * The device model and its Android version * More can be included here I codified these fields in a new `BenchmarkMetric` class, so that the benchmark results can be expressed as a list of different metrics in the result JSON. NB: Atm, the information about the model is extracted from its name, i.e. `NAME_BACKEND_QUANTIZATION.pte`, but it's better to get it from the file itself instead. Achieving this needs a bit more research. ### Testing https://github.com/pytorch/executorch/actions/runs/10843580072 * The JSON for `llama2`: ``` [ { "actual": 247, "arch": "SM-S901U1 / 12", "benchmarkModel": { "backend": "", "name": "llama2", "quantization": "" }, "device": "samsung", "metric": "model_load_time(ms)", "target": 0 }, { "actual": 367, "arch": "SM-S901U1 / 12", "benchmarkModel": { "backend": "", "name": "llama2", "quantization": "" }, "device": "samsung", "metric": "generate_time(ms)", "target": 0 }, { "actual": 342.69662, "arch": "SM-S901U1 / 12", "benchmarkModel": { "backend": "", "name": "llama2", "quantization": "" }, "device": "samsung", "metric": "token_per_sec", "target": 0 } ] ``` * The JSON for `mv2_xnnpack_q8`. I keep the average latency here as the final number to show later on the dashboard. 
``` [ { "actual": 91.1, "arch": "SM-S908U1 / 12", "benchmarkModel": { "backend": "xnnpack", "name": "mv2", "quantization": "q8" }, "device": "samsung", "metric": "avg_inference_latency(ms)", "target": 0 } ] ``` Pull Request resolved: #5332 Reviewed By: guangy10, kirklandsign Differential Revision: D62624549 Pulled By: huydhn fbshipit-source-id: 5c1a605c1012396ff904c148e9a99967c83321f6
1 parent 6d1a573 commit 0d1644f

File tree

4 files changed

+254
-39
lines changed

4 files changed

+254
-39
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
package com.example.executorchllamademo;
1010

1111
import android.app.Activity;
12+
import android.app.ActivityManager;
1213
import android.content.Intent;
14+
import android.os.Build;
1315
import android.os.Bundle;
1416
import android.util.Log;
1517
import android.widget.TextView;
@@ -18,7 +20,11 @@
1820
import java.io.File;
1921
import java.io.FileWriter;
2022
import java.io.IOException;
23+
import java.util.ArrayList;
2124
import java.util.Arrays;
25+
import java.util.List;
26+
import java.util.regex.Matcher;
27+
import java.util.regex.Pattern;
2228

2329
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
2430
ModelRunner mModelRunner;
@@ -50,19 +56,21 @@ protected void onCreate(Bundle savedInstanceState) {
5056
}
5157

5258
mStatsDump = new StatsDump();
59+
mStatsDump.modelName = model.getName().replace(".pte", "");
5360
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
54-
mStatsDump.loadStart = System.currentTimeMillis();
61+
mStatsDump.loadStart = System.nanoTime();
5562
}
5663

5764
@Override
5865
public void onModelLoaded(int status) {
59-
mStatsDump.loadEnd = System.currentTimeMillis();
66+
mStatsDump.loadEnd = System.nanoTime();
67+
mStatsDump.loadStatus = status;
6068
if (status != 0) {
6169
Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
6270
onGenerationStopped();
6371
return;
6472
}
65-
mStatsDump.generateStart = System.currentTimeMillis();
73+
mStatsDump.generateStart = System.nanoTime();
6674
mModelRunner.generate(mPrompt);
6775
}
6876

@@ -81,36 +89,122 @@ public void onStats(String stats) {
8189

8290
@Override
8391
public void onGenerationStopped() {
84-
mStatsDump.generateEnd = System.currentTimeMillis();
92+
mStatsDump.generateEnd = System.nanoTime();
8593
runOnUiThread(
8694
() -> {
8795
mTextView.append(mStatsDump.toString());
8896
});
8997

90-
// TODO (huydhn): Remove txt files here once the JSON format is ready
91-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
92-
writer.write(mStatsDump.toString());
93-
} catch (IOException e) {
94-
e.printStackTrace();
95-
}
98+
final BenchmarkMetric.BenchmarkModel benchmarkModel =
99+
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName);
100+
final List<BenchmarkMetric> results = new ArrayList<>();
101+
// The list of metrics we have atm includes:
102+
// Load status
103+
results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
104+
// Model load time
105+
results.add(
106+
new BenchmarkMetric(
107+
benchmarkModel,
108+
"model_load_time(ns)",
109+
mStatsDump.loadEnd - mStatsDump.loadStart,
110+
0.0f));
111+
// LLM generate time
112+
results.add(
113+
new BenchmarkMetric(
114+
benchmarkModel,
115+
"generate_time(ns)",
116+
mStatsDump.generateEnd - mStatsDump.generateStart,
117+
0.0f));
118+
// Token per second
119+
results.add(
120+
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));
96121

97-
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
98-
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
99122
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
100123
Gson gson = new Gson();
101-
writer.write(gson.toJson(mStatsDump));
124+
writer.write(gson.toJson(results));
102125
} catch (IOException e) {
103126
e.printStackTrace();
104127
}
105128
}
129+
130+
private double extractTPS(final String tokens) {
131+
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
132+
if (m.find()) {
133+
return Double.parseDouble(m.group());
134+
} else {
135+
return 0.0f;
136+
}
137+
}
138+
}
139+
140+
class BenchmarkMetric {
141+
public static class BenchmarkModel {
142+
// The model name, i.e. stories110M
143+
String name;
144+
String backend;
145+
String quantization;
146+
147+
public BenchmarkModel(final String name, final String backend, final String quantization) {
148+
this.name = name;
149+
this.backend = backend;
150+
this.quantization = quantization;
151+
}
152+
}
153+
154+
BenchmarkModel benchmarkModel;
155+
156+
// The metric name, i.e. TPS
157+
String metric;
158+
159+
// The actual value and the option target value
160+
double actualValue;
161+
double targetValue;
162+
163+
public static class DeviceInfo {
164+
// Let's see which information we want to include here
165+
final String device = Build.BRAND;
166+
// The phone model and Android release version
167+
final String arch = Build.MODEL;
168+
final String os = "Android " + Build.VERSION.RELEASE;
169+
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
170+
final long availMem = new ActivityManager.MemoryInfo().availMem;
171+
}
172+
173+
DeviceInfo deviceInfo = new DeviceInfo();
174+
175+
public BenchmarkMetric(
176+
final BenchmarkModel benchmarkModel,
177+
final String metric,
178+
final double actualValue,
179+
final double targetValue) {
180+
this.benchmarkModel = benchmarkModel;
181+
this.metric = metric;
182+
this.actualValue = actualValue;
183+
this.targetValue = targetValue;
184+
}
185+
186+
// TODO (huydhn): Figure out a way to extract the backend and quantization information from
187+
// the .pte model itself instead of parsing its name
188+
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
189+
final Matcher m =
190+
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
191+
if (m.matches()) {
192+
return new BenchmarkMetric.BenchmarkModel(
193+
m.group("name"), m.group("backend"), m.group("quantization"));
194+
} else {
195+
return new BenchmarkMetric.BenchmarkModel(model, "", "");
196+
}
197+
}
106198
}
107199

108200
class StatsDump {
201+
int loadStatus;
109202
long loadStart;
110203
long loadEnd;
111204
long generateStart;
112205
long generateEnd;
113206
String tokens;
207+
String modelName;
114208

115209
@NonNull
116210
@Override

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,49 @@ protected void onCreate(Bundle savedInstanceState) {
4747
// TODO: Format the string with a parsable format
4848
Stats stats = new Stats();
4949

50+
// Record the time it takes to load the model and the forward method
51+
stats.loadStart = System.nanoTime();
5052
Module module = Module.load(model.getPath());
53+
stats.errorCode = module.loadMethod("forward");
54+
stats.loadEnd = System.nanoTime();
55+
5156
for (int i = 0; i < numIter; i++) {
52-
long start = System.currentTimeMillis();
57+
long start = System.nanoTime();
5358
module.forward();
54-
long forwardMs = System.currentTimeMillis() - start;
59+
long forwardMs = System.nanoTime() - start;
5560
stats.latency.add(forwardMs);
5661
}
57-
stats.errorCode = module.loadMethod("forward");
5862

59-
// TODO (huydhn): Remove txt files here once the JSON format is ready
60-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
61-
writer.write(stats.toString());
62-
} catch (IOException e) {
63-
e.printStackTrace();
64-
}
63+
final BenchmarkMetric.BenchmarkModel benchmarkModel =
64+
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
65+
final List<BenchmarkMetric> results = new ArrayList<>();
66+
// The list of metrics we have atm includes:
67+
// Avg inference latency after N iterations
68+
results.add(
69+
new BenchmarkMetric(
70+
benchmarkModel,
71+
"avg_inference_latency(ns)",
72+
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
73+
0.0f));
74+
// Model load time
75+
results.add(
76+
new BenchmarkMetric(
77+
benchmarkModel, "model_load_time(ns)", stats.loadEnd - stats.loadStart, 0.0f));
78+
// Load status
79+
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
6580

66-
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
67-
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
6881
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
6982
Gson gson = new Gson();
70-
writer.write(gson.toJson(stats));
83+
writer.write(gson.toJson(results));
7184
} catch (IOException e) {
7285
e.printStackTrace();
7386
}
7487
}
7588
}
7689

7790
class Stats {
91+
long loadStart;
92+
long loadEnd;
7893
List<Long> latency = new ArrayList<>();
7994
int errorCode = 0;
8095

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.minibench;
10+
11+
import android.app.ActivityManager;
12+
import android.os.Build;
13+
import java.util.regex.Matcher;
14+
import java.util.regex.Pattern;
15+
16+
class BenchmarkMetric {
17+
public static class BenchmarkModel {
18+
// The model name, i.e. stories110M
19+
String name;
20+
String backend;
21+
String quantization;
22+
23+
public BenchmarkModel(final String name, final String backend, final String quantization) {
24+
this.name = name;
25+
this.backend = backend;
26+
this.quantization = quantization;
27+
}
28+
}
29+
30+
BenchmarkModel benchmarkModel;
31+
32+
// The metric name, i.e. TPS
33+
String metric;
34+
35+
// The actual value and the option target value
36+
double actualValue;
37+
double targetValue;
38+
39+
public static class DeviceInfo {
40+
// Let's see which information we want to include here
41+
final String device = Build.BRAND;
42+
// The phone model and Android release version
43+
final String arch = Build.MODEL;
44+
final String os = "Android " + Build.VERSION.RELEASE;
45+
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
46+
final long availMem = new ActivityManager.MemoryInfo().availMem;
47+
}
48+
49+
DeviceInfo deviceInfo = new DeviceInfo();
50+
51+
public BenchmarkMetric(
52+
final BenchmarkModel benchmarkModel,
53+
final String metric,
54+
final double actualValue,
55+
final double targetValue) {
56+
this.benchmarkModel = benchmarkModel;
57+
this.metric = metric;
58+
this.actualValue = actualValue;
59+
this.targetValue = targetValue;
60+
}
61+
62+
// TODO (huydhn): Figure out a way to extract the backend and quantization information from
63+
// the .pte model itself instead of parsing its name
64+
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
65+
final Matcher m =
66+
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
67+
if (m.matches()) {
68+
return new BenchmarkMetric.BenchmarkModel(
69+
m.group("name"), m.group("backend"), m.group("quantization"));
70+
} else {
71+
return new BenchmarkMetric.BenchmarkModel(model, "", "");
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)