Skip to content

Allow setting thread count from Java #10858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ xcuserdata/
.swiftpm/
*.xcworkspace/
*.xcframework/

# Android
*.aar
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,27 @@ public class Module {
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
*/
public static Module load(final String modelPath, int loadMode) {
return load(modelPath, loadMode, 0);
}

/**
* Loads a serialized ExecuTorch module from the specified path on the disk.
*
* @param modelPath path to file that contains the serialized ExecuTorch module.
* @param loadMode load mode for the module. See constants in {@link Module}.
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
* hardware-specific default.
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
*/
public static Module load(final String modelPath, int loadMode, int numThreads) {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
File modelFile = new File(modelPath);
if (!modelFile.canRead() || !modelFile.isFile()) {
throw new RuntimeException("Cannot load model path " + modelPath);
}
return new Module(new NativePeer(modelPath, loadMode));
return new Module(new NativePeer(modelPath, loadMode, numThreads));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ class NativePeer {
private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
private static native HybridData initHybrid(
String moduleAbsolutePath, int loadMode, int numThreads);

NativePeer(String moduleAbsolutePath, int loadMode) {
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
}

/** Clean up the native resources associated with this instance */
Expand Down
15 changes: 9 additions & 6 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
jint loadMode) {
return makeCxxInstance(modelPath, loadMode);
jint loadMode,
jint numThreads) {
return makeCxxInstance(modelPath, loadMode, numThreads);
}

ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
ExecuTorchJni(
facebook::jni::alias_ref<jstring> modelPath,
jint loadMode,
jint numThreads) {
Module::LoadMode load_mode = Module::LoadMode::Mmap;
if (loadMode == 0) {
load_mode = Module::LoadMode::File;
Expand All @@ -259,11 +263,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
// Based on testing, this is almost universally faster than using all
// cores, as efficiency cores can be quite slow. In extreme cases, using
// all cores can be 10x slower than using cores/2.
//
// TODO Allow overriding this default from Java.
auto threadpool = executorch::extension::threadpool::get_threadpool();
if (threadpool) {
int thread_count = cpuinfo_get_processors_count() / 2;
int thread_count =
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
if (thread_count > 0) {
threadpool->_unsafe_reset_threadpool(thread_count);
}
Expand Down
Loading