Skip to content

Commit 9a53ffe

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Add api to explicitly load a method (#2408)
Summary: We lazy load methods during the first execute (forward). However, sometimes we want to explicitly load the method before the first inference, so that we get consistent performance between all runs. Adding a method so that user can load method, if they wish. Otherwise, they can still depend on execute to lazy load for them. Pull Request resolved: #2408 Test Plan: CI Explicitly call this: ``` diff --git a/xplat/executorch/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java b/xplat/executorch/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java --- a/xplat/executorch/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java +++ b/xplat/executorch/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java @@ -101,6 +101,7 @@ try { mModule = Module.load(MainActivity.assetFilePath(getApplicationContext(), "dl3_xnnpack_fp32.pte")); + mModule.loadMethod("forward"); } catch (IOException e) { Log.e("ImageSegmentation", "Error reading assets", e); @@ -136,6 +137,7 @@ mModule = Module.load( MainActivity.assetFilePath(getApplicationContext(), "dl3_xnnpack_fp32.pte")); + mModule.loadMethod("forward"); } catch (IOException e) { Log.e("ImageSegmentation", "Error reading assets", e); finish(); ``` ``` buck install fbandroid/mode/opt xplat/executorch/examples/demo-apps/android/ExecuTorchDemo/app/src/main:ExecuTorchDemo ``` and run app Differential Revision: D54868635 Pulled By: kirklandsign fbshipit-source-id: a680862020b41a73e1aa2c68a84cd68c22cb18c9
1 parent ad9f186 commit 9a53ffe

File tree

4 files changed

+28
-0
lines changed

4 files changed

+28
-0
lines changed

extension/android/jni/jni_layer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
275275
return execute_method(methodName->toStdString(), jinputs);
276276
}
277277

278+
jint load_method(facebook::jni::alias_ref<jstring> methodName) {
279+
return static_cast<jint>(module_->load_method(methodName->toStdString()));
280+
}
281+
278282
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
279283
std::string method,
280284
facebook::jni::alias_ref<
@@ -343,6 +347,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
343347
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
344348
makeNativeMethod("forward", ExecuTorchJni::forward),
345349
makeNativeMethod("execute", ExecuTorchJni::execute),
350+
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
346351
});
347352
}
348353
};

extension/android/src/main/java/org/pytorch/executorch/INativePeer.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,11 @@ interface INativePeer {
1818

1919
/** Run an arbitrary method on the module */
2020
EValue[] execute(String methodName, EValue... inputs);
21+
22+
/**
23+
* Load a method on this module.
24+
*
25+
* @return the Error code if there was an error loading the method
26+
*/
27+
int loadMethod(String methodName);
2128
}

extension/android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ public EValue[] execute(String methodName, EValue... inputs) {
6767
return mNativePeer.execute(methodName, inputs);
6868
}
6969

70+
/**
71+
* Load a method on this module. This might help with the first time inference performance,
72+
* because otherwise the method is loaded lazily when it's execute. Note: this function is
73+
* synchronous, and will block until the method is loaded. Therefore, it is recommended to call
74+
* this on a background thread. However, users need to make sure that they don't execute before
75+
* this function returns.
76+
*
77+
* @return the Error code if there was an error loading the method
78+
*/
79+
public int loadMethod(String methodName) {
80+
return mNativePeer.loadMethod(methodName);
81+
}
82+
7083
/**
7184
* Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the
7285
* native object will be destroyed when this object is garbage-collected. However, the timing of

extension/android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,7 @@ public void resetNative() {
3838

3939
@DoNotStrip
4040
public native EValue[] execute(String methodName, EValue... inputs);
41+
42+
@DoNotStrip
43+
public native int loadMethod(String methodName);
4144
}

0 commit comments

Comments
 (0)