Skip to content

Commit ed3ee79

Browse files
committed
integration
1 parent 438efc8 commit ed3ee79

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

extension/android/jni/jni_layer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
354354
}
355355

356356
jint forward_ones() {
357+
auto&& load_result = module_->load_method("forward");
357358
auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method));
358-
auto&& result = module_->forward();
359-
return (jint) result.error();
359+
auto&& result = module_->methods_["forward"].method->execute();
360+
return (jint) result;
360361
}
361362

362-
363363
static void registerNatives() {
364364
registerHybrid({
365365
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),

extension/android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ public static Module load(final String modelPath) {
7979
* @return return value from the 'forward' method.
8080
*/
8181
public EValue[] forward(EValue... inputs) {
82+
if (inputs.length == 0) {
83+
// forward default args (ones)
84+
mNativePeer.forwardOnes();
85+
// discard the return value
86+
return null;
87+
}
8288
return mNativePeer.forward(inputs);
8389
}
8490

0 commit comments

Comments
 (0)