Skip to content

Commit 47d3afb

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Allow setting thread count from Java
Summary: Allow passing in an explicit thread count through the Java bindings. Differential Revision: D74676049
1 parent d0360b7 commit 47d3afb

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,22 @@ public static Module load(final String modelPath, int loadMode) {
5252
if (!NativeLoader.isInitialized()) {
5353
NativeLoader.init(new SystemDelegate());
5454
}
55-
return new Module(new NativePeer(modelPath, loadMode));
55+
return new Module(new NativePeer(modelPath, loadMode, 0));
56+
}
57+
58+
/**
59+
* Loads a serialized ExecuTorch module from the specified path on the disk.
60+
*
61+
* @param modelPath path to file that contains the serialized ExecuTorch module.
62+
* @param loadMode load mode for the module. See constants in {@link Module}.
63+
* @param numThreads number of threads to use for inference.
64+
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
65+
*/
66+
public static Module load(final String modelPath, int loadMode, int numThreads) {
67+
if (!NativeLoader.isInitialized()) {
68+
NativeLoader.init(new SystemDelegate());
69+
}
70+
return new Module(new NativePeer(modelPath, loadMode, numThreads));
5671
}
5772

5873
/**

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ class NativePeer {
2828
private final HybridData mHybridData;
2929

3030
@DoNotStrip
31-
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
31+
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode, int numThreads);
3232

33-
NativePeer(String moduleAbsolutePath, int loadMode) {
34-
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
33+
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
34+
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
3535
}
3636

3737
/** Clean up the native resources associated with this instance */

extension/android/jni/jni_layer.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
222222
static facebook::jni::local_ref<jhybriddata> initHybrid(
223223
facebook::jni::alias_ref<jclass>,
224224
facebook::jni::alias_ref<jstring> modelPath,
225-
jint loadMode) {
226-
return makeCxxInstance(modelPath, loadMode);
225+
jint loadMode,
226+
jint numThreads) {
227+
return makeCxxInstance(modelPath, loadMode, numThreads);
227228
}
228229

229-
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
230+
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode, jint numThreads) {
230231
Module::LoadMode load_mode = Module::LoadMode::Mmap;
231232
if (loadMode == 0) {
232233
load_mode = Module::LoadMode::File;
@@ -248,11 +249,9 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
248249
// Based on testing, this is almost universally faster than using all
249250
// cores, as efficiency cores can be quite slow. In extreme cases, using
250251
// all cores can be 10x slower than using cores/2.
251-
//
252-
// TODO Allow overriding this default from Java.
253252
auto threadpool = executorch::extension::threadpool::get_threadpool();
254253
if (threadpool) {
255-
int thread_count = cpuinfo_get_processors_count() / 2;
254+
int thread_count = numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
256255
if (thread_count > 0) {
257256
threadpool->_unsafe_reset_threadpool(thread_count);
258257
}

0 commit comments

Comments
 (0)