
Commit e134941

Android app side update

1 parent dff3368 commit e134941

4 files changed: +53 −15 lines


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 13 additions & 14 deletions

```diff
@@ -30,19 +30,26 @@ public class MainActivity extends Activity implements Runnable, LlamaCallback {
   private LlamaModule mModule = null;
   private Message mResultMessage = null;
 
-  private int mNumTokens = 0;
-  private long mRunStartTime = 0;
   private String mModelFilePath = "";
   private String mTokenizerFilePath = "";
 
   @Override
   public void onResult(String result) {
-    System.out.println("onResult: " + result);
     mResultMessage.appendText(result);
-    mNumTokens++;
     run();
   }
 
+  @Override
+  public void onStats(float tps) {
+    System.out.println("LLAMAERROR ERRORRRRRR");
+    runOnUiThread(() -> {
+      if (mResultMessage != null) {
+        mResultMessage.setTokensPerSecond(tps);
+        mMessageAdapter.notifyDataSetChanged();
+      }
+    });
+  }
+
   private static String[] listLocalFile(String path, String suffix) {
     File directory = new File(path);
     if (directory.exists() && directory.isDirectory()) {
@@ -79,14 +86,14 @@ private void setLocalModel(String modelPath, String tokenizerPath) {
           });
     }
 
-    long runDuration = System.currentTimeMillis() - runStartTime;
+    long loadDuration = System.currentTimeMillis() - runStartTime;
     String modelInfo =
         "Model path: "
             + modelPath
            + "\nTokenizer path: "
            + tokenizerPath
            + "\nModel loaded time: "
-            + runDuration
+            + loadDuration
            + " ms";
     Message modelLoadedMessage = new Message(modelInfo, false);
     runOnUiThread(
@@ -175,16 +182,10 @@ private void onModelRunStarted() {
         view -> {
           mModule.stop();
         });
-
-    mRunStartTime = System.currentTimeMillis();
   }
 
   private void onModelRunStopped() {
     setTitle(memoryInfo());
-    long runDuration = System.currentTimeMillis() - mRunStartTime;
-    if (mResultMessage != null) {
-      mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f));
-    }
     mSendButton.setText("Generate");
     mSendButton.setOnClickListener(
         view -> {
@@ -219,8 +220,6 @@ public void run() {
              };
          new Thread(runnable).start();
        });
-    mNumTokens = 0;
-    mRunStartTime = 0;
    mMessageAdapter.notifyDataSetChanged();
  }
```
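The activity no longer times generation itself: the wall-clock bookkeeping (mNumTokens, mRunStartTime) is deleted, and the tokens-per-second figure now arrives through the new onStats() callback. (The stray System.out.println in onStats() appears to be leftover debug output in the commit.) Since JNI callbacks fire on the native runner's thread, any view mutation has to be posted back to the main thread, which is why onStats() wraps its work in runOnUiThread(). A minimal sketch of the same pattern, assuming a plain TextView named mStatusView instead of the demo's Message/adapter classes:

```java
// Illustrative only: mStatusView is a hypothetical TextView, not part of the demo app.
@Override
public void onStats(float tps) {
  // JNI invokes this on the runner thread; post UI work to the main thread.
  runOnUiThread(() -> mStatusView.setText(String.format("%.2f tokens/s", tps)));
}
```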

extension/android/jni/jni_layer_llama.cpp

Lines changed: 13 additions & 0 deletions

```diff
@@ -72,6 +72,17 @@ class ExecuTorchLlamaCallbackJni
     facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
     method(self(), s);
   }
+
+  void onStats(const Runner::TimeStampsAndStats& result) const {
+    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static const auto method =
+        cls->getMethod<void(jfloat)>("onStats");
+    double eval_time = (double)(result.inference_end_ms -
+        result.prompt_eval_end_ms);
+    float tps = result.num_generated_tokens / eval_time *
+        result.SCALING_FACTOR_UNITS_PER_SECOND;
+    method(self(), tps);
+  }
 };
 
 class ExecuTorchLlamaJni
@@ -119,6 +130,8 @@ class ExecuTorchLlamaJni
     runner_->generate(
         prompt->toStdString(), 128, [callback](std::string result) {
           callback->onResult(result);
+        }, [callback](const Runner::TimeStampsAndStats& result) {
+          callback->onStats(result);
         });
     return 0;
   }
```
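The JNI layer derives tokens/second from the runner's own timestamps: generated tokens divided by the time between the end of prompt evaluation and the end of inference, scaled from the timestamp unit into seconds. Assuming millisecond timestamps and SCALING_FACTOR_UNITS_PER_SECOND equal to 1000 (both are assumptions; the Runner definitions are not shown in this diff), the same arithmetic in Java looks like:

```java
// Sketch of the tps computation under the stated assumptions
// (millisecond timestamps, scaling factor 1000, i.e. ms -> s).
static float tokensPerSecond(long promptEvalEndMs, long inferenceEndMs, long numGeneratedTokens) {
  double evalTimeMs = inferenceEndMs - promptEvalEndMs; // decode-phase duration
  return (float) (numGeneratedTokens / evalTimeMs * 1000.0);
}
// Example: 96 tokens over 12,000 ms of decoding -> 8.0 tokens/s.
```

Note that the division happens before scaling, so an eval time of zero (e.g. nothing generated) would yield a non-finite value; callers may want to guard against that.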

extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java

Lines changed: 11 additions & 1 deletion

```diff
@@ -11,7 +11,17 @@
 import com.facebook.jni.annotations.DoNotStrip;
 
 public interface LlamaCallback {
-  /** Called when a new result is available from JNI. User should override this method. */
+  /** Called when a new result is available from JNI.
+   * Users will keep getting onResult() invocations until generate() finishes.
+   * @param result Last generated token
+   */
   @DoNotStrip
   public void onResult(String result);
+
+  /** Called when the statistics for the generate() is available.
+   * @param tps Tokens/second for generated tokens.
+   */
+  @DoNotStrip
+  public void onStats(float tps);
+
 }
```
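Any consumer of the runner implements both methods: onResult() streams one token at a time, and onStats() delivers throughput once generation completes. A minimal sketch of an implementation that just accumulates output and logs throughput (the class name and logging are illustrative, not part of this commit):

```java
import org.pytorch.executorch.LlamaCallback;

// Hypothetical helper for tests or logging; not part of the demo app.
public class LoggingCallback implements LlamaCallback {
  private final StringBuilder mOutput = new StringBuilder();

  @Override
  public void onResult(String result) {
    mOutput.append(result); // invoked once per generated token until generate() finishes
  }

  @Override
  public void onStats(float tps) {
    System.out.println("generate() stats: " + tps + " tokens/s");
    System.out.println("full output: " + mOutput);
  }
}
```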

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 16 additions & 0 deletions

```diff
@@ -27,6 +27,10 @@ public class LlamaModule {
   private static native HybridData initHybrid(
       String modulePath, String tokenizerPath, float temperature);
 
+  /**
+   * Constructs a LLAMA Module for a model with given path, tokenizer,
+   * and temperature.
+   */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
     mHybridData = initHybrid(modulePath, tokenizerPath, temperature);
   }
@@ -35,12 +39,24 @@ public void resetNative() {
     mHybridData.resetNative();
   }
 
+  /**
+   * Start generating tokens from the module.
+   * @param prompt Input prompt
+   * @param llamaCallback callback object to receive results.
+   */
   @DoNotStrip
   public native int generate(String prompt, LlamaCallback llamaCallback);
 
+  /**
+   * Stop current generate() before it finishes.
+   */
   @DoNotStrip
   public native void stop();
 
+  /**
+   * Force loading the module. Otherwise the model is loaded during first
+   * generate().
+   */
   @DoNotStrip
   public native int load();
 }
```
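Taken together, the documented API reads: construct the module, optionally force-load it, then drive generation through a callback. A hypothetical end-to-end usage, with placeholder file paths and the LoggingCallback sketched above:

```java
// Paths and temperature are placeholders; point them at wherever the
// exported model and tokenizer actually live on the device.
LlamaModule module =
    new LlamaModule("/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
module.load();                          // optional: avoid a lazy load on the first generate()
module.generate("Hello, my name is", new LoggingCallback());
// From another thread, module.stop() interrupts a running generate().
```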
