Skip to content

Commit 7b795d7

Browse files
authored
Make seq_len param available in JNI layer generate()
Differential Revision: D61343892 Pull Request resolved: #4745
1 parent 96e7f0a commit 7b795d7

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ class ExecuTorchLlamaJni
127127

128128
jint generate(
129129
facebook::jni::alias_ref<jstring> prompt,
130+
jint seq_len,
130131
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
131132
runner_->generate(
132133
prompt->toStdString(),
133-
128,
134+
seq_len,
134135
[callback](std::string result) { callback->onResult(result); },
135136
[callback](const Stats& result) { callback->onStats(result); });
136137
return 0;

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public class LlamaModule {
2222
}
2323

2424
private final HybridData mHybridData;
25+
private static final int DEFAULT_SEQ_LEN = 128;
2526

2627
@DoNotStrip
2728
private static native HybridData initHybrid(
@@ -42,8 +43,19 @@ public void resetNative() {
4243
* @param prompt Input prompt
4344
* @param llamaCallback callback object to receive results.
4445
*/
46+
public int generate(String prompt, LlamaCallback llamaCallback) {
47+
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
48+
}
49+
50+
/**
51+
* Start generating tokens from the module.
52+
*
53+
* @param prompt Input prompt
54+
* @param seqLen sequence length
55+
* @param llamaCallback callback object to receive results.
56+
*/
4557
@DoNotStrip
46-
public native int generate(String prompt, LlamaCallback llamaCallback);
58+
public native int generate(String prompt, int seqLen, LlamaCallback llamaCallback);
4759

4860
/** Stop current generate() before it finishes. */
4961
@DoNotStrip

0 commit comments

Comments
 (0)