Skip to content

Commit eab1be3

Browse files
committed
Allow setting thread count from Java
Summary: Allow passing in an explicit thread count through the Java bindings. Differential Revision: D74676049
1 parent 5866c19 commit eab1be3

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# Android
45+
*.aar

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,19 @@ public class Module {
4747
*
4848
* @param modelPath path to file that contains the serialized ExecuTorch module.
4949
* @param loadMode load mode for the module. See constants in {@link Module}.
50+
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
51+
* hardware-specific default.
5052
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
5153
*/
52-
public static Module load(final String modelPath, int loadMode) {
54+
public static Module load(final String modelPath, int loadMode, int numThreads = 0) {
5355
if (!NativeLoader.isInitialized()) {
5456
NativeLoader.init(new SystemDelegate());
5557
}
5658
File modelFile = new File(modelPath);
5759
if (!modelFile.canRead() || !modelFile.isFile()) {
5860
throw new RuntimeException("Cannot load model path " + modelPath);
5961
}
60-
return new Module(new NativePeer(modelPath, loadMode));
62+
return new Module(new NativePeer(modelPath, loadMode, numThreads));
6163
}
6264

6365
/**

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class NativePeer {
2828
private final HybridData mHybridData;
2929

3030
@DoNotStrip
31-
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
31+
private static native HybridData initHybrid(
32+
String moduleAbsolutePath, int loadMode, int numThreads);
3233

33-
NativePeer(String moduleAbsolutePath, int loadMode) {
34-
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
34+
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
35+
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
3536
}
3637

3738
/** Clean up the native resources associated with this instance */

extension/android/jni/jni_layer.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
228228
static facebook::jni::local_ref<jhybriddata> initHybrid(
229229
facebook::jni::alias_ref<jclass>,
230230
facebook::jni::alias_ref<jstring> modelPath,
231-
jint loadMode) {
232-
return makeCxxInstance(modelPath, loadMode);
231+
jint loadMode,
232+
jint numThreads) {
233+
return makeCxxInstance(modelPath, loadMode, numThreads);
233234
}
234235

235-
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
236+
ExecuTorchJni(
237+
facebook::jni::alias_ref<jstring> modelPath,
238+
jint loadMode,
239+
jint numThreads) {
236240
Module::LoadMode load_mode = Module::LoadMode::Mmap;
237241
if (loadMode == 0) {
238242
load_mode = Module::LoadMode::File;
@@ -259,11 +263,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
259263
// Based on testing, this is almost universally faster than using all
260264
// cores, as efficiency cores can be quite slow. In extreme cases, using
261265
// all cores can be 10x slower than using cores/2.
262-
//
263-
// TODO Allow overriding this default from Java.
264266
auto threadpool = executorch::extension::threadpool::get_threadpool();
265267
if (threadpool) {
266-
int thread_count = cpuinfo_get_processors_count() / 2;
268+
int thread_count =
269+
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
267270
if (thread_count > 0) {
268271
threadpool->_unsafe_reset_threadpool(thread_count);
269272
}

0 commit comments

Comments
 (0)