Skip to content

Commit 317d749

Browse files
committed
Unify order of echo parameter to be last in all layers
1 parent 24081a7 commit 317d749

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -624,8 +624,8 @@ public void run() {
624624
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
625625
prompt,
626626
ModelUtils.VISION_MODEL_SEQ_LEN,
627-
false,
628-
MainActivity.this);
627+
MainActivity.this,
628+
false);
629629
} else {
630630
// no image selected, we pass in empty int array
631631
mModule.generate(
@@ -635,12 +635,12 @@ public void run() {
635635
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
636636
prompt,
637637
ModelUtils.VISION_MODEL_SEQ_LEN,
638-
false,
639-
MainActivity.this);
638+
MainActivity.this,
639+
false);
640640
}
641641
} else {
642642
mModule.generate(
643-
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
643+
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this, false);
644644
}
645645

646646
long generateDuration = System.currentTimeMillis() - generateStartTime;

extension/android/jni/jni_layer_llama.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class ExecuTorchLlamaJni
150150
jint channels,
151151
facebook::jni::alias_ref<jstring> prompt,
152152
jint seq_len,
153-
jboolean echo,
154-
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
153+
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
154+
jboolean echo) {
155155
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
156156
auto image_size = image->size();
157157
std::vector<Image> images;
@@ -249,8 +249,8 @@ class ExecuTorchLlamaJni
249249
facebook::jni::alias_ref<jstring> prompt,
250250
jint seq_len,
251251
jlong start_pos,
252-
jboolean echo,
253-
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
252+
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
253+
jboolean echo) {
254254
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
255255
return static_cast<jint>(Error::NotSupported);
256256
}

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public void resetNative() {
6060
* @param llamaCallback callback object to receive results.
6161
*/
6262
public int generate(String prompt, LlamaCallback llamaCallback) {
63-
return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
63+
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
6464
}
6565

6666
/**
@@ -71,30 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
7171
* @param llamaCallback callback object to receive results.
7272
*/
7373
public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
74-
return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
74+
return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO);
7575
}
7676

7777
/**
7878
* Start generating tokens from the module.
7979
*
8080
* @param prompt Input prompt
81+
* @param llamaCallback callback object to receive results
8182
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
82-
* @param llamaCallback callback object to receive results.
8383
*/
84-
public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
85-
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
84+
public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) {
85+
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo);
8686
}
8787

8888
/**
8989
* Start generating tokens from the module.
9090
*
9191
* @param prompt Input prompt
9292
* @param seqLen sequence length
93+
* @param llamaCallback callback object to receive results
9394
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
94-
* @param llamaCallback callback object to receive results.
9595
*/
96-
public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
97-
return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
96+
public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) {
97+
return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo);
9898
}
9999

100100
/**
@@ -106,8 +106,8 @@ public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llama
106106
* @param channels Input image number of channels
107107
* @param prompt Input prompt
108108
* @param seqLen sequence length
109-
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
110109
* @param llamaCallback callback object to receive results.
110+
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
111111
*/
112112
@DoNotStrip
113113
public native int generate(
@@ -117,8 +117,8 @@ public native int generate(
117117
int channels,
118118
String prompt,
119119
int seqLen,
120-
boolean echo,
121-
LlamaCallback llamaCallback);
120+
LlamaCallback llamaCallback,
121+
boolean echo);
122122

123123
/**
124124
* Prefill an LLaVA Module with the given images input.
@@ -171,12 +171,12 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
171171
* @param prompt The text prompt to LLaVA.
172172
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
173173
* @param startPos The starting position in KV cache of the input in the LLM.
174-
* @param echo indicate whether to echo the input prompt or not
175174
* @param llamaCallback callback object to receive results.
175+
* @param echo indicate whether to echo the input prompt or not.
176176
* @return The error code.
177177
*/
178178
public native int generateFromPos(
179-
String prompt, int seqLen, long startPos, boolean echo, LlamaCallback callback);
179+
String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
180180

181181
/** Stop current generate() before it finishes. */
182182
@DoNotStrip

0 commit comments

Comments (0)