Skip to content

Commit 133c5a7

Browse files
committed
Allow setting thread count from Java
Summary: Allow passing in an explicit thread count through the Java bindings. Differential Revision: D74676049
1 parent d0360b7 commit 133c5a7

File tree

4 files changed

+33
-10
lines changed

4 files changed

+33
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# Android
45+
*.aar

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,23 @@ public static Module load(final String modelPath, int loadMode) {
5252
if (!NativeLoader.isInitialized()) {
5353
NativeLoader.init(new SystemDelegate());
5454
}
55-
return new Module(new NativePeer(modelPath, loadMode));
55+
return new Module(new NativePeer(modelPath, loadMode, 0));
56+
}
57+
58+
/**
59+
* Loads a serialized ExecuTorch module from the specified path on the disk.
60+
*
61+
* @param modelPath path to file that contains the serialized ExecuTorch module.
62+
* @param loadMode load mode for the module. See constants in {@link Module}.
63+
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
64+
* hardware-specific default.
65+
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
66+
*/
67+
public static Module load(final String modelPath, int loadMode, int numThreads) {
68+
if (!NativeLoader.isInitialized()) {
69+
NativeLoader.init(new SystemDelegate());
70+
}
71+
return new Module(new NativePeer(modelPath, loadMode, numThreads));
5672
}
5773

5874
/**

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class NativePeer {
2828
private final HybridData mHybridData;
2929

3030
@DoNotStrip
31-
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
31+
private static native HybridData initHybrid(
32+
String moduleAbsolutePath, int loadMode, int numThreads);
3233

33-
NativePeer(String moduleAbsolutePath, int loadMode) {
34-
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
34+
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
35+
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
3536
}
3637

3738
/** Clean up the native resources associated with this instance */

extension/android/jni/jni_layer.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
222222
static facebook::jni::local_ref<jhybriddata> initHybrid(
223223
facebook::jni::alias_ref<jclass>,
224224
facebook::jni::alias_ref<jstring> modelPath,
225-
jint loadMode) {
226-
return makeCxxInstance(modelPath, loadMode);
225+
jint loadMode,
226+
jint numThreads) {
227+
return makeCxxInstance(modelPath, loadMode, numThreads);
227228
}
228229

229-
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
230+
ExecuTorchJni(
231+
facebook::jni::alias_ref<jstring> modelPath,
232+
jint loadMode,
233+
jint numThreads) {
230234
Module::LoadMode load_mode = Module::LoadMode::Mmap;
231235
if (loadMode == 0) {
232236
load_mode = Module::LoadMode::File;
@@ -248,11 +252,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
248252
// Based on testing, this is almost universally faster than using all
249253
// cores, as efficiency cores can be quite slow. In extreme cases, using
250254
// all cores can be 10x slower than using cores/2.
251-
//
252-
// TODO Allow overriding this default from Java.
253255
auto threadpool = executorch::extension::threadpool::get_threadpool();
254256
if (threadpool) {
255-
int thread_count = cpuinfo_get_processors_count() / 2;
257+
int thread_count =
258+
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
256259
if (thread_count > 0) {
257260
threadpool->_unsafe_reset_threadpool(thread_count);
258261
}

0 commit comments

Comments
 (0)