2 files changed: +15 -2 lines changed

JNI layer (class ExecuTorchLlamaJni):

@@ -127,10 +127,11 @@ class ExecuTorchLlamaJni
   jint generate(
       facebook::jni::alias_ref<jstring> prompt,
+      jint seq_len,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     runner_->generate(
         prompt->toStdString(),
-        128,
+        seq_len,
         [callback](std::string result) { callback->onResult(result); },
         [callback](const Stats& result) { callback->onStats(result); });
     return 0;
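The JNI wrapper now threads the caller-supplied seq_len through to runner_->generate() in place of the previously hard-coded 128. The Java changes below add the matching parameter to the public API while keeping the old default.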
src/main/java/org/pytorch/executorch/LlamaModule.java:

@@ -22,6 +22,7 @@ public class LlamaModule {
   }
 
   private final HybridData mHybridData;
+  private static final int DEFAULT_SEQ_LEN = 128;
 
   @DoNotStrip
   private static native HybridData initHybrid(
@@ -42,8 +43,19 @@ public void resetNative() {
    * @param prompt Input prompt
    * @param llamaCallback callback object to receive results.
    */
+  public int generate(String prompt, LlamaCallback llamaCallback) {
+    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param llamaCallback callback object to receive results.
+   */
   @DoNotStrip
-  public native int generate(String prompt, LlamaCallback llamaCallback);
+  public native int generate(String prompt, int seqLen, LlamaCallback llamaCallback);
 
   /** Stop current generate() before it finishes. */
   @DoNotStrip
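For callers the change is backward compatible: the existing two-argument generate() now delegates to the native method with DEFAULT_SEQ_LEN (128), the value the JNI layer used to hard-code, while the new overload takes an explicit sequence length. A minimal caller-side sketch; the wrapper class is illustrative, and it assumes LlamaCallback lives in the same org.pytorch.executorch package and that the module and callback are constructed elsewhere (their setup is not part of this diff):

    import org.pytorch.executorch.LlamaCallback;
    import org.pytorch.executorch.LlamaModule;

    public final class GenerateExample {
      // `module` and `callback` are assumed to be created and wired up
      // elsewhere; only the generate() calls below come from this diff.
      static void run(LlamaModule module, LlamaCallback callback) {
        // Two-argument form: unchanged behavior, generates up to
        // DEFAULT_SEQ_LEN (128) tokens.
        module.generate("Once upon a time,", callback);

        // New three-argument form: request a different generation length.
        module.generate("Once upon a time,", 256, callback);
      }
    }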