
Commit e245590

App side change
Differential Revision: D62458651
Pull Request resolved: #5205
1 parent ced40f4 commit e245590

File tree

5 files changed, +244 -0 lines changed


extension/android/benchmark/app/build.gradle.kts

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ dependencies {
     implementation(files("libs/executorch.aar"))
     implementation("com.facebook.soloader:soloader:0.10.5")
     implementation("com.facebook.fbjni:fbjni:0.5.1")
+    implementation("com.google.code.gson:gson:2.8.6")
     testImplementation("junit:junit:4.13.2")
     androidTestImplementation("androidx.test.ext:junit:1.2.1")
     androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")

extension/android/benchmark/app/src/main/AndroidManifest.xml

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,14 @@
         </intent-filter>
     </activity>

+        <activity
+            android:name=".LlmBenchmarkActivity"
+            android:exported="true">
+            <intent-filter>
+                <action android:name="org.pytorch.minibench.BENCHMARK" />
+            </intent-filter>
+        </activity>
+
     </application>

</manifest>
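
Because the new activity is exported and filters on the org.pytorch.minibench.BENCHMARK action, it can be started from outside the app. Below is a minimal sketch of such a launch; BenchmarkLauncher is a hypothetical helper, the target package is assumed to match the app's package name, and the extras are the ones read by LlmBenchmarkActivity further down.

import android.content.Context;
import android.content.Intent;

public final class BenchmarkLauncher {
  // Hypothetical helper that fires the BENCHMARK intent with the extras
  // LlmBenchmarkActivity reads in onCreate().
  public static void launch(Context context, String modelDir, String tokenizerPath) {
    Intent intent = new Intent("org.pytorch.minibench.BENCHMARK");
    intent.setPackage("org.pytorch.minibench"); // assumed applicationId of the benchmark app
    intent.putExtra("model_dir", modelDir); // directory containing a .pte model
    intent.putExtra("tokenizer_path", tokenizerPath); // tokenizer file path
    intent.putExtra("prompt", "The ultimate answer"); // optional; this is the default prompt
    intent.putExtra("temperature", 0.8f); // optional; 0.8f is the default
    intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK); // needed when starting from a non-Activity context
    context.startActivity(intent);
  }
}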

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.minibench;

import android.app.Activity;
import android.content.Intent;
import android.os.Bundle;
import android.util.Log;
import com.google.gson.Gson;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;

public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
  ModelRunner mModelRunner;

  String mPrompt;
  StatsInfo mStatsInfo;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    Intent intent = getIntent();

    File modelDir = new File(intent.getStringExtra("model_dir"));
    File model =
        Arrays.stream(modelDir.listFiles())
            .filter(file -> file.getName().endsWith(".pte"))
            .findFirst()
            .get();
    String tokenizerPath = intent.getStringExtra("tokenizer_path");

    float temperature = intent.getFloatExtra("temperature", 0.8f);
    mPrompt = intent.getStringExtra("prompt");
    if (mPrompt == null) {
      mPrompt = "The ultimate answer";
    }

    mStatsInfo = new StatsInfo();
    mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
    mStatsInfo.loadStart = System.currentTimeMillis();
  }

  @Override
  public void onModelLoaded(int status) {
    mStatsInfo.loadEnd = System.currentTimeMillis();
    if (status != 0) {
      Log.e("LlmBenchmarkRunner", "Load failed: " + status);
      onGenerationStopped();
      return;
    }
    mStatsInfo.generateStart = System.currentTimeMillis();
    mModelRunner.generate(mPrompt);
  }

  @Override
  public void onTokenGenerated(String token) {}

  @Override
  public void onStats(String stats) {
    mStatsInfo.tokens = stats;
  }

  @Override
  public void onGenerationStopped() {
    mStatsInfo.generateEnd = System.currentTimeMillis();

    // TODO (huydhn): Remove txt files here once the JSON format is ready
    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
      writer.write(mStatsInfo.toString());
    } catch (IOException e) {
      e.printStackTrace();
    }

    // TODO (huydhn): Figure out what the final JSON results should look like; we need something
    // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
      Gson gson = new Gson();
      writer.write(gson.toJson(mStatsInfo));
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}

class StatsInfo {
  long loadStart;
  long loadEnd;
  long generateStart;
  long generateEnd;
  String tokens;

  @Override
  public String toString() {
    return "loadStart: "
        + loadStart
        + "\nloadEnd: "
        + loadEnd
        + "\ngenerateStart: "
        + generateStart
        + "\ngenerateEnd: "
        + generateEnd
        + "\n"
        + tokens;
  }
}
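
Since StatsInfo is serialized with its field names intact, the benchmark_results.json file can be read back with the same Gson dependency. A minimal sketch of a reader, assuming it lives in the same package (StatsInfo is package-private) and using a placeholder path; on device the activity writes to getFilesDir() + "/benchmark_results.json":

package org.pytorch.minibench;

import com.google.gson.Gson;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;

public class ReadBenchmarkResults {
  public static void main(String[] args) throws IOException {
    // Placeholder path; substitute the file pulled from the device.
    try (Reader reader = new FileReader("benchmark_results.json")) {
      StatsInfo stats = new Gson().fromJson(reader, StatsInfo.class);
      System.out.println("model load: " + (stats.loadEnd - stats.loadStart) + " ms");
      System.out.println("generation: " + (stats.generateEnd - stats.generateStart) + " ms");
      System.out.println(stats.tokens);
    }
  }
}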

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.minibench;

import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.os.Message;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

/** A helper class that handles all of the model running logic for this app. */
public class ModelRunner implements LlamaCallback {
  LlamaModule mModule = null;

  String mModelFilePath = "";
  String mTokenizerFilePath = "";

  ModelRunnerCallback mCallback = null;

  HandlerThread mHandlerThread = null;
  Handler mHandler = null;

  /**
   * Helper class to separate UI logic from model runner logic. Automatically handles generate()
   * requests on a worker thread.
   *
   * @param modelFilePath path to the .pte model file
   * @param tokenizerFilePath path to the tokenizer file
   * @param temperature sampling temperature passed to the LlamaModule
   * @param callback receives load, token, stats, and completion events
   */
  ModelRunner(
      String modelFilePath,
      String tokenizerFilePath,
      float temperature,
      ModelRunnerCallback callback) {
    mModelFilePath = modelFilePath;
    mTokenizerFilePath = tokenizerFilePath;
    mCallback = callback;

    mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, temperature);
    mHandlerThread = new HandlerThread("ModelRunner");
    mHandlerThread.start();
    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);

    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
  }

  int generate(String prompt) {
    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
    msg.sendToTarget();
    return 0;
  }

  void stop() {
    mModule.stop();
  }

  @Override
  public void onResult(String result) {
    mCallback.onTokenGenerated(result);
  }

  @Override
  public void onStats(float tps) {
    mCallback.onStats("tokens/second: " + tps);
  }
}

class ModelRunnerHandler extends Handler {
  public static final int MESSAGE_LOAD_MODEL = 1;
  public static final int MESSAGE_GENERATE = 2;

  private final ModelRunner mModelRunner;

  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
    super(looper);
    mModelRunner = modelRunner;
  }

  @Override
  public void handleMessage(Message msg) {
    if (msg.what == MESSAGE_LOAD_MODEL) {
      int status = mModelRunner.mModule.load();
      mModelRunner.mCallback.onModelLoaded(status);
    } else if (msg.what == MESSAGE_GENERATE) {
      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
      mModelRunner.mCallback.onGenerationStopped();
    }
  }
}

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.minibench;

/**
 * A helper interface within the app for MainActivity and Benchmarking to handle callbacks from
 * ModelRunner.
 */
public interface ModelRunnerCallback {

  void onModelLoaded(int status);

  void onTokenGenerated(String token);

  void onStats(String stats);

  void onGenerationStopped();
}
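
Taken together, ModelRunner and ModelRunnerCallback form a small asynchronous API: the caller constructs a ModelRunner with a callback, model loading and generation run on the runner's HandlerThread, and results come back through the callback, which is exactly how LlmBenchmarkActivity uses them. A minimal sketch of another caller, assuming it sits in the same package (the ModelRunner constructor is package-private) and using placeholder file paths:

package org.pytorch.minibench;

import android.util.Log;

// Hypothetical caller; LlmBenchmarkActivity above is the real consumer in this commit.
public class ModelRunnerExample implements ModelRunnerCallback {
  private ModelRunner mRunner;

  void run() {
    // Placeholder paths; the benchmark app gets these from Intent extras.
    mRunner =
        new ModelRunner("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f, this);
    // Loading is queued on the runner's HandlerThread; wait for onModelLoaded().
  }

  @Override
  public void onModelLoaded(int status) {
    if (status == 0) {
      mRunner.generate("The ultimate answer"); // queued as MESSAGE_GENERATE on the worker thread
    }
  }

  @Override
  public void onTokenGenerated(String token) {
    Log.d("ModelRunnerExample", "token: " + token);
  }

  @Override
  public void onStats(String stats) {
    Log.d("ModelRunnerExample", stats); // e.g. "tokens/second: 12.5"
  }

  @Override
  public void onGenerationStopped() {
    Log.d("ModelRunnerExample", "generation finished");
  }
}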
