Skip to content

Commit 2fa10fd

Browse files
committed
Add a new method metadata class, and allow users to get methods
1 parent b69a166 commit 2fa10fd

File tree

5 files changed

+86
-16
lines changed

5 files changed

+86
-16
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public void testXnnpackBackendRequired() throws IOException, URISyntaxException
9999

100100
Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"));
101101
String[] expectedBackends = new String[] {"XnnpackBackend"};
102-
assertArrayEquals(expectedBackends, module.getUsedBackends("forward"));
102+
assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends());
103103
}
104104

105105
@Test

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public void testModuleLoadMethodNonExistantMethod() throws IOException{
109109
assertEquals(loadMethod, INVALID_ARGUMENT);
110110
}
111111

112-
@Test
112+
@Test(expected = RuntimeException.class)
113113
public void testNonPteFile() throws IOException{
114114
Module module = Module.load(getTestFilePath(NON_PTE_FILE_NAME));
115115

extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,33 @@
88

99
package org.pytorch.executorch;
1010

11-
/**
12-
* Helper class to access the metadata for a method from a Module
13-
*/
11+
/** Helper class to access the metadata for a method from a Module */
1412
public class MethodMetadata {
15-
private String name;
13+
private String mName;
14+
15+
private String[] mBackends;
16+
17+
MethodMetadata setName(String name) {
18+
mName = name;
19+
return this;
20+
}
21+
22+
/**
23+
* @return Method name
24+
*/
25+
public String getName() {
26+
return mName;
27+
}
28+
29+
MethodMetadata setBackends(String[] backends) {
30+
mBackends = backends;
31+
return this;
32+
}
1633

17-
public String getName() {
18-
return name;
19-
}
34+
/**
35+
* @return Backends used for this method
36+
*/
37+
public String[] getBackends() {
38+
return mBackends;
39+
}
2040
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
package org.pytorch.executorch;
1010

1111
import android.util.Log;
12-
1312
import com.facebook.jni.HybridData;
1413
import com.facebook.jni.annotations.DoNotStrip;
1514
import com.facebook.soloader.nativeloader.NativeLoader;
1615
import com.facebook.soloader.nativeloader.SystemDelegate;
1716
import java.io.File;
17+
import java.util.HashMap;
18+
import java.util.Map;
1819
import java.util.concurrent.locks.Lock;
1920
import java.util.concurrent.locks.ReentrantLock;
2021
import org.pytorch.executorch.annotations.Experimental;
@@ -49,19 +50,26 @@ public class Module {
4950

5051
private final HybridData mHybridData;
5152

52-
private final MethodMetadata[] mMethodMetadata;
53+
private final Map<String, MethodMetadata> mMethodMetadata;
5354

5455
@DoNotStrip
5556
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
5657

5758
private Module(String moduleAbsolutePath, int loadMode) {
5859
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
5960

60-
populateMethodMeta();
61+
mMethodMetadata = populateMethodMeta();
6162
}
6263

63-
void populateMethodMeta() {
64-
mMethodMetadata = new MethodMetadata[1];
64+
Map<String, MethodMetadata> populateMethodMeta() {
65+
String[] methods = getMethods();
66+
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
67+
for (int i = 0; i < methods.length; i++) {
68+
String name = methods[i];
69+
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
70+
}
71+
72+
return metadata;
6573
}
6674

6775
/** Lock protecting the non-thread safe methods in mHybridData. */
@@ -153,13 +161,34 @@ public int loadMethod(String methodName) {
153161
private native int loadMethodNative(String methodName);
154162

155163
/**
156-
* Returns the names of the methods in a certain method.
164+
* Returns the names of the backends in a certain method.
157165
*
158166
* @param methodName method name to query
159167
* @return an array of backend name
160168
*/
161169
@DoNotStrip
162-
public native String[] getUsedBackends(String methodName);
170+
private native String[] getUsedBackends(String methodName);
171+
172+
/**
173+
* Returns the names of methods.
174+
*
175+
* @return name of methods in this Module
176+
*/
177+
@DoNotStrip
178+
public native String[] getMethods();
179+
180+
/**
181+
* Get the corresponding @MethodMetadata for a method
182+
*
183+
* @param name method name
184+
* @return @MethodMetadata for this method
185+
*/
186+
public MethodMetadata getMethodMetadata(String name) {
187+
if (!mMethodMetadata.containsKey(name)) {
188+
throw new RuntimeException("method " + name + "does not exist for this module");
189+
}
190+
return mMethodMetadata.get(name);
191+
}
163192

164193
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
165194
@DoNotStrip

extension/android/jni/jni_layer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
429429
return false;
430430
}
431431

432+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
433+
const auto& names_result = module_->method_names();
434+
if (!names_result.ok()) {
435+
facebook::jni::throwNewJavaException(
436+
facebook::jni::gJavaLangIllegalArgumentException,
437+
"Cannot get load module");
438+
}
439+
const auto& methods = names_result.get();
440+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
441+
facebook::jni::JArrayClass<jstring>::newArray(methods.size());
442+
int i = 0;
443+
for (auto s : methods) {
444+
facebook::jni::local_ref<facebook::jni::JString> method_name =
445+
facebook::jni::make_jstring(s.c_str());
446+
(*ret)[i] = method_name;
447+
i++;
448+
}
449+
return ret;
450+
}
451+
432452
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
433453
facebook::jni::alias_ref<jstring> methodName) {
434454
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
@@ -456,6 +476,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
456476
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
457477
makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
458478
makeNativeMethod("etdump", ExecuTorchJni::etdump),
479+
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
459480
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
460481
});
461482
}

0 commit comments

Comments
 (0)