Skip to content

Android backend used by method #10934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 17, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.pytorch.executorch;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
Expand Down Expand Up @@ -89,6 +90,18 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc
assertEquals(bananaClass, argmax(scores));
}

@Test
public void testXnnpackBackendRequired() throws IOException, URISyntaxException {
File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte"));
InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte");
FileUtils.copyInputStreamToFile(inputStream, pteFile);
inputStream.close();

Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"));
String[] expectedBackends = new String[] {"XnnpackBackend"};
assertArrayEquals(expectedBackends, module.getUsedBackends("forward"));
}

@Test
public void testMv2Fp32() throws IOException, URISyntaxException {
testClassification("/mv2_xnnpack_fp32.pte");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ public int loadMethod(String methodName) {
}
}

/**
* Returns the names of the methods in a certain method.
*
* @param methodName method name to query
* @return an array of backend name
*/
public String[] getUsedBackends(String methodName) {
return mNativePeer.getUsedBackends(methodName);
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
return mNativePeer.readLogBuffer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ public void resetNative() {
@DoNotStrip
public native int loadMethod(String methodName);

/** Return the list of backends used by a method */
@DoNotStrip
public native String[] getUsedBackends(String methodName);

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
@DoNotStrip
public native String[] readLogBuffer();
Expand Down
22 changes: 22 additions & 0 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "jni_layer_constants.h"
Expand Down Expand Up @@ -395,13 +396,34 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
#endif
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
facebook::jni::alias_ref<jstring> methodName) {
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
std::unordered_set<std::string> backends;
for (auto i = 0; i < methodMeta.num_backends(); i++) {
backends.insert(methodMeta.get_backend_name(i).get());
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
facebook::jni::JArrayClass<jstring>::newArray(backends.size());
int i = 0;
for (auto s : backends) {
facebook::jni::local_ref<facebook::jni::JString> backend_name =
facebook::jni::make_jstring(s.c_str());
(*ret)[i] = backend_name;
i++;
}
return ret;
}

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
makeNativeMethod("forward", ExecuTorchJni::forward),
makeNativeMethod("execute", ExecuTorchJni::execute),
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
});
}
};
Expand Down
Loading