Skip to content

Android method metadata #11023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 21, 2025
1 change: 1 addition & 0 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ non_fbcode_target(_kind = fb_android_library,
srcs = [
"executorch_android/src/main/java/org/pytorch/executorch/DType.java",
"executorch_android/src/main/java/org/pytorch/executorch/EValue.java",
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ModuleE2ETest {

val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
val expectedBackends = arrayOf("XnnpackBackend")
Assert.assertArrayEquals(expectedBackends, module.getUsedBackends("forward"))
Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends())
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ class ModuleInstrumentationTest {
Assert.assertTrue(results[0].isTensor)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testMethodMetadata() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

Assert.assertArrayEquals(arrayOf("forward"), module.getMethods())
Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty())
}

@Test
@Throws(IOException::class)
fun testModuleLoadMethodAndForward() {
Expand Down Expand Up @@ -91,7 +100,7 @@ class ModuleInstrumentationTest {
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
}

@Test
@Test(expected = RuntimeException::class)
@Throws(IOException::class)
fun testNonPteFile() {
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

/** Helper class to access the metadata for a method from a Module */
public class MethodMetadata {
private String mName;

private String[] mBackends;

MethodMetadata setName(String name) {
mName = name;
return this;
}

/**
* @return Method name
*/
public String getName() {
return mName;
}

MethodMetadata setBackends(String[] backends) {
mBackends = backends;
return this;
}

/**
* @return Backends used for this method
*/
public String[] getBackends() {
return mBackends;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.pytorch.executorch.annotations.Experimental;
Expand Down Expand Up @@ -48,12 +50,27 @@ public class Module {

private final HybridData mHybridData;

private final Map<String, MethodMetadata> mMethodMetadata;

@DoNotStrip
private static native HybridData initHybrid(
String moduleAbsolutePath, int loadMode, int initHybrid);

private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);

mMethodMetadata = populateMethodMeta();
}

Map<String, MethodMetadata> populateMethodMeta() {
String[] methods = getMethods();
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
for (int i = 0; i < methods.length; i++) {
String name = methods[i];
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
}

return metadata;
}

/** Lock protecting the non-thread safe methods in mHybridData. */
Expand Down Expand Up @@ -158,13 +175,34 @@ public int loadMethod(String methodName) {
private native int loadMethodNative(String methodName);

/**
* Returns the names of the methods in a certain method.
* Returns the names of the backends in a certain method.
*
* @param methodName method name to query
* @return an array of backend name
*/
@DoNotStrip
public native String[] getUsedBackends(String methodName);
private native String[] getUsedBackends(String methodName);

/**
* Returns the names of methods.
*
* @return name of methods in this Module
*/
@DoNotStrip
public native String[] getMethods();

/**
* Get the corresponding @MethodMetadata for a method
*
* @param name method name
* @return @MethodMetadata for this method
*/
public MethodMetadata getMethodMetadata(String name) {
if (!mMethodMetadata.containsKey(name)) {
throw new RuntimeException("method " + name + "does not exist for this module");
}
return mMethodMetadata.get(name);
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
Expand Down
21 changes: 21 additions & 0 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
return false;
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
const auto& names_result = module_->method_names();
if (!names_result.ok()) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Cannot get load module");
}
const auto& methods = names_result.get();
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
facebook::jni::JArrayClass<jstring>::newArray(methods.size());
int i = 0;
for (auto s : methods) {
facebook::jni::local_ref<facebook::jni::JString> method_name =
facebook::jni::make_jstring(s.c_str());
(*ret)[i] = method_name;
i++;
}
return ret;
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
facebook::jni::alias_ref<jstring> methodName) {
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
Expand Down Expand Up @@ -458,6 +478,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
makeNativeMethod("etdump", ExecuTorchJni::etdump),
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
});
}
Expand Down
Loading