Skip to content

Commit 4e38f4a

Browse files
authored
Allow setting thread count from Java
Differential Revision: D74676049 Pull Request resolved: #10858
1 parent 074b392 commit 4e38f4a

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# Android
45+
*.aar

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,27 @@ public class Module {
5050
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
5151
*/
5252
public static Module load(final String modelPath, int loadMode) {
53+
return load(modelPath, loadMode, 0);
54+
}
55+
56+
/**
57+
* Loads a serialized ExecuTorch module from the specified path on the disk.
58+
*
59+
* @param modelPath path to file that contains the serialized ExecuTorch module.
60+
* @param loadMode load mode for the module. See constants in {@link Module}.
61+
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
62+
* hardware-specific default.
63+
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
64+
*/
65+
public static Module load(final String modelPath, int loadMode, int numThreads) {
5366
if (!NativeLoader.isInitialized()) {
5467
NativeLoader.init(new SystemDelegate());
5568
}
5669
File modelFile = new File(modelPath);
5770
if (!modelFile.canRead() || !modelFile.isFile()) {
5871
throw new RuntimeException("Cannot load model path " + modelPath);
5972
}
60-
return new Module(new NativePeer(modelPath, loadMode));
73+
return new Module(new NativePeer(modelPath, loadMode, numThreads));
6174
}
6275

6376
/**

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class NativePeer {
2828
private final HybridData mHybridData;
2929

3030
@DoNotStrip
31-
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
31+
private static native HybridData initHybrid(
32+
String moduleAbsolutePath, int loadMode, int numThreads);
3233

33-
NativePeer(String moduleAbsolutePath, int loadMode) {
34-
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
34+
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
35+
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
3536
}
3637

3738
/** Clean up the native resources associated with this instance */

extension/android/jni/jni_layer.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
227227
static facebook::jni::local_ref<jhybriddata> initHybrid(
228228
facebook::jni::alias_ref<jclass>,
229229
facebook::jni::alias_ref<jstring> modelPath,
230-
jint loadMode) {
231-
return makeCxxInstance(modelPath, loadMode);
230+
jint loadMode,
231+
jint numThreads) {
232+
return makeCxxInstance(modelPath, loadMode, numThreads);
232233
}
233234

234-
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
235+
ExecuTorchJni(
236+
facebook::jni::alias_ref<jstring> modelPath,
237+
jint loadMode,
238+
jint numThreads) {
235239
Module::LoadMode load_mode = Module::LoadMode::Mmap;
236240
if (loadMode == 0) {
237241
load_mode = Module::LoadMode::File;
@@ -258,11 +262,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
258262
// Based on testing, this is almost universally faster than using all
259263
// cores, as efficiency cores can be quite slow. In extreme cases, using
260264
// all cores can be 10x slower than using cores/2.
261-
//
262-
// TODO Allow overriding this default from Java.
263265
auto threadpool = executorch::extension::threadpool::get_threadpool();
264266
if (threadpool) {
265-
int thread_count = cpuinfo_get_processors_count() / 2;
267+
int thread_count =
268+
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
266269
if (thread_count > 0) {
267270
threadpool->_unsafe_reset_threadpool(thread_count);
268271
}

0 commit comments

Comments
 (0)