Skip to content

Commit 7510f8c

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Run benchmark on background thread (#6320)
Summary: Don't run on UI thread `onCreate()`. Use a worker thread. Pull Request resolved: #6320 Reviewed By: huydhn Differential Revision: D64563455 Pulled By: kirklandsign fbshipit-source-id: bda066663e02d14344025d24cd52993998c5782f
1 parent 8f6c16e commit 7510f8c

File tree

1 file changed

+49
-34
lines changed

1 file changed

+49
-34
lines changed

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import android.app.Activity;
1212
import android.content.Intent;
13+
import android.os.AsyncTask;
1314
import android.os.Bundle;
1415
import android.system.ErrnoException;
1516
import android.system.Os;
@@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) {
4748
// TODO: Format the string with a parsable format
4849
Stats stats = new Stats();
4950

50-
// Record the time it takes to load the model and the forward method
51-
stats.loadStart = System.nanoTime();
52-
Module module = Module.load(model.getPath());
53-
stats.errorCode = module.loadMethod("forward");
54-
stats.loadEnd = System.nanoTime();
51+
new AsyncTask<Void, Void, Void>() {
52+
@Override
53+
protected Void doInBackground(Void... voids) {
5554

56-
for (int i = 0; i < numIter; i++) {
57-
long start = System.nanoTime();
58-
module.forward();
59-
double forwardMs = (System.nanoTime() - start) * 1e-6;
60-
stats.latency.add(forwardMs);
61-
}
55+
// Record the time it takes to load the model and the forward method
56+
stats.loadStart = System.nanoTime();
57+
Module module = Module.load(model.getPath());
58+
stats.errorCode = module.loadMethod("forward");
59+
stats.loadEnd = System.nanoTime();
6260

63-
final BenchmarkMetric.BenchmarkModel benchmarkModel =
64-
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
65-
final List<BenchmarkMetric> results = new ArrayList<>();
66-
// The list of metrics we have atm includes:
67-
// Avg inference latency after N iterations
68-
results.add(
69-
new BenchmarkMetric(
70-
benchmarkModel,
71-
"avg_inference_latency(ms)",
72-
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
73-
0.0f));
74-
// Model load time
75-
results.add(
76-
new BenchmarkMetric(
77-
benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f));
78-
// Load status
79-
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
61+
for (int i = 0; i < numIter; i++) {
62+
long start = System.nanoTime();
63+
module.forward();
64+
double forwardMs = (System.nanoTime() - start) * 1e-6;
65+
stats.latency.add(forwardMs);
66+
}
67+
return null;
68+
}
8069

81-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
82-
Gson gson = new Gson();
83-
writer.write(gson.toJson(results));
84-
} catch (IOException e) {
85-
e.printStackTrace();
86-
}
70+
@Override
71+
protected void onPostExecute(Void aVoid) {
72+
73+
final BenchmarkMetric.BenchmarkModel benchmarkModel =
74+
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
75+
final List<BenchmarkMetric> results = new ArrayList<>();
76+
// The list of metrics we have atm includes:
77+
// Avg inference latency after N iterations
78+
results.add(
79+
new BenchmarkMetric(
80+
benchmarkModel,
81+
"avg_inference_latency(ms)",
82+
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
83+
0.0f));
84+
// Model load time
85+
results.add(
86+
new BenchmarkMetric(
87+
benchmarkModel,
88+
"model_load_time(ms)",
89+
(stats.loadEnd - stats.loadStart) * 1e-6,
90+
0.0f));
91+
// Load status
92+
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
93+
94+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
95+
Gson gson = new Gson();
96+
writer.write(gson.toJson(results));
97+
} catch (IOException e) {
98+
e.printStackTrace();
99+
}
100+
}
101+
}.execute();
87102
}
88103
}
89104

0 commit comments

Comments
 (0)