Skip to content

Commit 2662099

Browse files
Riandy authored and facebook-github-bot committed
Make seq_len param available in JNI layer generate() (#4745)
Summary: Pull Request resolved: #4745 - Previously, on the JNI side, we fixed seq_len to 128. - In this diff, we expose seq_len as a parameter in generate(), so devs can customize it as needed. Reviewed By: kirklandsign Differential Revision: D61343892
1 parent 5c9a00a commit 2662099

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ class ExecuTorchLlamaJni
127127

128128
jint generate(
129129
facebook::jni::alias_ref<jstring> prompt,
130+
jint seq_len,
130131
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
131132
runner_->generate(
132133
prompt->toStdString(),
133-
128,
134+
seq_len,
134135
[callback](std::string result) { callback->onResult(result); },
135136
[callback](const Stats& result) { callback->onStats(result); });
136137
return 0;

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public class LlamaModule {
2222
}
2323

2424
private final HybridData mHybridData;
25+
private static final int DEFAULT_SEQ_LEN = 128;
2526

2627
@DoNotStrip
2728
private static native HybridData initHybrid(
@@ -42,8 +43,19 @@ public void resetNative() {
4243
* @param prompt Input prompt
4344
* @param llamaCallback callback object to receive results.
4445
*/
46+
public int generate(String prompt, LlamaCallback llamaCallback) {
47+
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
48+
}
49+
50+
/**
51+
* Start generating tokens from the module.
52+
*
53+
* @param prompt Input prompt
54+
* @param seqLen sequence length
55+
* @param llamaCallback callback object to receive results.
56+
*/
4557
@DoNotStrip
46-
public native int generate(String prompt, LlamaCallback llamaCallback);
58+
public native int generate(String prompt, int seqLen, LlamaCallback llamaCallback);
4759

4860
/** Stop current generate() before it finishes. */
4961
@DoNotStrip

0 commit comments

Comments
 (0)