Skip to content

Commit 32d83b0

Browse files
authored
[Android Java] Get rid of forwardOnes
Differential Revision: D62327373 Pull Request resolved: #5153
1 parent 8ff79ef commit 32d83b0

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

extension/android/jni/jni_layer.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,29 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
294294
facebook::jni::alias_ref<
295295
facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
296296
jinputs) {
297+
// If no inputs is given, it will run with sample inputs (ones)
298+
if (jinputs->size() == 0) {
299+
if (module_->load_method(method) != Error::Ok) {
300+
return {};
301+
}
302+
auto&& underlying_method = module_->methods_[method].method;
303+
auto&& buf = prepare_input_tensors(*underlying_method);
304+
auto result = underlying_method->execute();
305+
if (result != Error::Ok) {
306+
return {};
307+
}
308+
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
309+
facebook::jni::JArrayClass<JEValue>::newArray(
310+
underlying_method->outputs_size());
311+
312+
for (int i = 0; i < underlying_method->outputs_size(); i++) {
313+
auto jevalue =
314+
JEValue::newJEValueFromEValue(underlying_method->get_output(i));
315+
jresult->setElement(i, *jevalue);
316+
}
317+
return jresult;
318+
}
319+
297320
std::vector<EValue> evalues;
298321
std::vector<TensorPtr> tensors;
299322

@@ -352,20 +375,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
352375
return jresult;
353376
}
354377

355-
jint forward_ones() {
356-
auto&& load_result = module_->load_method("forward");
357-
auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method));
358-
auto&& result = module_->methods_["forward"].method->execute();
359-
return (jint)result;
360-
}
361-
362378
static void registerNatives() {
363379
registerHybrid({
364380
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
365381
makeNativeMethod("forward", ExecuTorchJni::forward),
366382
makeNativeMethod("execute", ExecuTorchJni::execute),
367383
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
368-
makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones),
369384
});
370385
}
371386
};

extension/android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,12 @@ public static Module load(final String modelPath) {
7979
/**
8080
* Runs the 'forward' method of this module with the specified arguments.
8181
*
82-
* @param inputs arguments for the ExecuTorch module's 'forward' method.
82+
* @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
83+
* requires inputs but no inputs are given, the function will not error out, but run 'forward'
84+
* with sample inputs.
8385
* @return return value from the 'forward' method.
8486
*/
8587
public EValue[] forward(EValue... inputs) {
86-
if (inputs.length == 0) {
87-
// forward default args (ones)
88-
mNativePeer.forwardOnes();
89-
// discard the return value
90-
return null;
91-
}
9288
return mNativePeer.forward(inputs);
9389
}
9490

extension/android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,6 @@ public void resetNative() {
4343
@DoNotStrip
4444
public native EValue[] forward(EValue... inputs);
4545

46-
/**
47-
* Run a "forward" call with the sample inputs (ones) to test a module
48-
*
49-
* @return the outputs of the forward call
50-
* @apiNote This is experimental and test-only API
51-
*/
52-
@DoNotStrip
53-
public native int forwardOnes();
54-
5546
/** Run an arbitrary method on the module */
5647
@DoNotStrip
5748
public native EValue[] execute(String methodName, EValue... inputs);

0 commit comments

Comments
 (0)