
Commit 8d8fe09

kirklandsign authored and facebook-github-bot committed

Android app use stats from runner (#2801)

Summary: Instead of calculating the stats in the app, we use the numbers reported by the runner, so they match the binary's. For now we only report generated t/s from the binary. TODO: Create a Java class for the other stats so that the Java layer can get them through JNI.

Pull Request resolved: #2801
Reviewed By: shoumikhin
Differential Revision: D55776409
Pulled By: kirklandsign
fbshipit-source-id: 116a939b703408a4b67d3b694213617a42ff2b81

1 parent f646371 commit 8d8fe09

File tree: 4 files changed, +52 −18 lines
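Before the per-file diffs, a rough sketch of what the change amounts to on the Java side. This is illustrative only, not code from the commit; the field and method names are borrowed from the diffs below.

// Before vs. after, in miniature.
public class TpsBeforeAfter {
  private float lastTps;

  // Before this commit: the app timed generation itself and divided
  // its own token count by wall-clock time.
  static float appSideTps(int numTokens, long runDurationMs) {
    return 1.0f * numTokens / (runDurationMs / 1000.0f);
  }

  // After this commit: the JNI layer forwards the runner's own
  // measurement, so the Java side just stores what onStats() hands it,
  // and the app shows the same number as the command-line binary.
  void onStats(float tps) {
    lastTps = tps;
  }
}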

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 13 additions & 14 deletions

@@ -30,19 +30,26 @@ public class MainActivity extends Activity implements Runnable, LlamaCallback {
   private LlamaModule mModule = null;
   private Message mResultMessage = null;

-  private int mNumTokens = 0;
-  private long mRunStartTime = 0;
   private String mModelFilePath = "";
   private String mTokenizerFilePath = "";

   @Override
   public void onResult(String result) {
-    System.out.println("onResult: " + result);
     mResultMessage.appendText(result);
-    mNumTokens++;
     run();
   }

+  @Override
+  public void onStats(float tps) {
+    runOnUiThread(
+        () -> {
+          if (mResultMessage != null) {
+            mResultMessage.setTokensPerSecond(tps);
+            mMessageAdapter.notifyDataSetChanged();
+          }
+        });
+  }
+
   private static String[] listLocalFile(String path, String suffix) {
     File directory = new File(path);
     if (directory.exists() && directory.isDirectory()) {
@@ -79,14 +86,14 @@ private void setLocalModel(String modelPath, String tokenizerPath) {
        });
    }

-    long runDuration = System.currentTimeMillis() - runStartTime;
+    long loadDuration = System.currentTimeMillis() - runStartTime;
    String modelInfo =
        "Model path: "
            + modelPath
            + "\nTokenizer path: "
            + tokenizerPath
            + "\nModel loaded time: "
-            + runDuration
+            + loadDuration
            + " ms";
    Message modelLoadedMessage = new Message(modelInfo, false);
    runOnUiThread(
@@ -175,16 +182,10 @@ private void onModelRunStarted() {
        view -> {
          mModule.stop();
        });
-
-    mRunStartTime = System.currentTimeMillis();
  }

  private void onModelRunStopped() {
    setTitle(memoryInfo());
-    long runDuration = System.currentTimeMillis() - mRunStartTime;
-    if (mResultMessage != null) {
-      mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
-    }
    mSendButton.setText("Generate");
    mSendButton.setOnClickListener(
        view -> {
@@ -219,8 +220,6 @@ public void run() {
          };
          new Thread(runnable).start();
        });
-    mNumTokens = 0;
-    mRunStartTime = 0;
    mMessageAdapter.notifyDataSetChanged();
  }
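A note on the threading in the hunks above: the demo runs generate() on a background thread (see the new Thread(runnable).start() call), so onResult() and onStats() arrive off the main thread, and UI updates must be posted with runOnUiThread(). A minimal sketch of that pattern, using a hypothetical Activity/TextView pair that is not part of the demo:

import android.app.Activity;
import android.widget.TextView;

// Hypothetical helper mirroring the pattern MainActivity uses above.
public class StatsUiUpdater {
  private final Activity activity;
  private final TextView statsView;

  public StatsUiUpdater(Activity activity, TextView statsView) {
    this.activity = activity;
    this.statsView = statsView;
  }

  // Safe to call from any thread; the lambda runs on the UI thread.
  public void onStats(float tps) {
    activity.runOnUiThread(() -> statsView.setText(tps + " tokens/s"));
  }
}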

extension/android/jni/jni_layer_llama.cpp

Lines changed: 16 additions & 3 deletions

@@ -72,6 +72,18 @@ class ExecuTorchLlamaCallbackJni
     facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
     method(self(), s);
   }
+
+  void onStats(const Runner::Stats& result) const {
+    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static const auto method = cls->getMethod<void(jfloat)>("onStats");
+    double eval_time =
+        (double)(result.inference_end_ms - result.prompt_eval_end_ms);
+
+    float tps = result.num_generated_tokens / eval_time *
+        result.SCALING_FACTOR_UNITS_PER_SECOND;
+
+    method(self(), tps);
+  }
 };

 class ExecuTorchLlamaJni
@@ -117,9 +129,10 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     runner_->generate(
-        prompt->toStdString(), 128, [callback](std::string result) {
-          callback->onResult(result);
-        });
+        prompt->toStdString(),
+        128,
+        [callback](std::string result) { callback->onResult(result); },
+        [callback](const Runner::Stats& result) { callback->onStats(result); });
     return 0;
   }
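To make the tokens/second formula in onStats() above concrete, here is a worked example in Java. One assumption not visible in this diff: the Stats timestamps are in milliseconds, i.e. SCALING_FACTOR_UNITS_PER_SECOND is 1000.

// Worked example of the tokens/second computation from onStats() above.
public class TpsExample {
  // Assumes millisecond timestamps (SCALING_FACTOR_UNITS_PER_SECOND == 1000).
  static float tokensPerSecond(long promptEvalEndMs, long inferenceEndMs, long numGeneratedTokens) {
    double evalTimeMs = inferenceEndMs - promptEvalEndMs;
    return (float) (numGeneratedTokens / evalTimeMs * 1000.0);
  }

  public static void main(String[] args) {
    // 128 tokens decoded over 6.4 seconds of eval time -> 20 tokens/second.
    System.out.println(tokensPerSecond(10_000, 16_400, 128)); // prints 20.0
  }
}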

extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java

Lines changed: 14 additions & 1 deletion

@@ -11,7 +11,20 @@
 import com.facebook.jni.annotations.DoNotStrip;

 public interface LlamaCallback {
-  /** Called when a new result is available from JNI. User should override this method. */
+  /**
+   * Called when a new result is available from JNI. Users will keep getting onResult() invocations
+   * until generate() finishes.
+   *
+   * @param result Last generated token
+   */
   @DoNotStrip
   public void onResult(String result);
+
+  /**
+   * Called when the statistics for the generate() is available.
+   *
+   * @param tps Tokens/second for generated tokens.
+   */
+  @DoNotStrip
+  public void onStats(float tps);
 }
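A usage sketch for the interface above: an implementation that accumulates the streamed tokens and records the reported throughput. The StringBuilder accumulation is illustrative and not taken from the demo app.

import org.pytorch.executorch.LlamaCallback;

// Illustrative callback: collect every streamed token plus the final stats.
public class CollectingCallback implements LlamaCallback {
  private final StringBuilder text = new StringBuilder();
  private float tokensPerSecond;

  @Override
  public void onResult(String result) {
    text.append(result); // invoked once per generated token until generate() finishes
  }

  @Override
  public void onStats(float tps) {
    tokensPerSecond = tps; // latest tokens/second value reported by the runner
  }

  public String getText() {
    return text.toString();
  }

  public float getTokensPerSecond() {
    return tokensPerSecond;
  }
}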

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 9 additions & 0 deletions

@@ -27,6 +27,7 @@ public class LlamaModule {
   private static native HybridData initHybrid(
       String modulePath, String tokenizerPath, float temperature);

+  /** Constructs a LLAMA Module for a model with given path, tokenizer, and temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
     mHybridData = initHybrid(modulePath, tokenizerPath, temperature);
   }
@@ -35,12 +36,20 @@ public void resetNative() {
     mHybridData.resetNative();
   }

+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param llamaCallback callback object to receive results.
+   */
   @DoNotStrip
   public native int generate(String prompt, LlamaCallback llamaCallback);

+  /** Stop current generate() before it finishes. */
   @DoNotStrip
   public native void stop();

+  /** Force loading the module. Otherwise the model is loaded during first generate(). */
   @DoNotStrip
   public native int load();
 }
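Tying the pieces together, a sketch of driving the API above end to end. The file paths are placeholders, CollectingCallback is the illustrative class from the previous sketch, and the demo app runs generate() on its own thread, which this sketch omits for brevity.

import org.pytorch.executorch.LlamaModule;

public class GenerateExample {
  public static void main(String[] args) {
    // Placeholder paths; point these at a real model and tokenizer.
    LlamaModule module = new LlamaModule(
        "/data/local/tmp/llama/model.pte", "/data/local/tmp/llama/tokenizer.bin", 0.8f);
    module.load(); // optional: load eagerly instead of on the first generate()

    CollectingCallback callback = new CollectingCallback();
    module.generate("Hello, world!", callback); // streams tokens via onResult()

    System.out.println(callback.getText());
    System.out.println(callback.getTokensPerSecond() + " t/s");
    module.resetNative(); // release the native module when done
  }
}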
