Skip to content

Commit 61ddee5

Browse files
authored
Android library update for benchmarking support
Test: ``` cd extension/android/benchmark mkdir app/libs cp <executorch.aar> app/libs/executorch.aar ./gradlew :app:installDebug adb shell am start -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity --es model_path /data/local/tmp/model.pte adb shell run-as org.pytorch.minibench cat files/benchmark_results.txt ```
1 parent 36c1f54 commit 61ddee5

File tree

7 files changed

+50
-8
lines changed

7 files changed

+50
-8
lines changed

build/build_android_llm_demo.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ build_android_native_library() {
3030
-DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
3131
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
3232
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
33+
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
3334
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
3435
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
3536
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \

extension/android/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,15 @@ find_package(executorch CONFIG REQUIRED)
3232
target_link_options_shared_lib(executorch)
3333

3434
set(link_libraries)
35-
list(APPEND link_libraries extension_data_loader extension_module extension_threadpool executorch
36-
fbjni
35+
list(
36+
APPEND
37+
link_libraries
38+
executorch
39+
extension_data_loader
40+
extension_module
41+
extension_runner_util
42+
extension_threadpool
43+
fbjni
3744
)
3845

3946
if(TARGET optimized_native_cpu_ops_lib)

extension/android/jni/jni_layer.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "jni_layer_constants.h"
1919

2020
#include <executorch/extension/module/module.h>
21+
#include <executorch/extension/runner_util/inputs.h>
2122
#include <executorch/extension/runner_util/managed_tensor.h>
2223
#include <executorch/runtime/core/portable_type/tensor_impl.h>
2324
#include <executorch/runtime/platform/log.h>
@@ -56,7 +57,7 @@ void et_pal_emit_log_message(
5657

5758
using namespace torch::executor;
5859

59-
namespace executorch_jni {
60+
namespace executorch::extension {
6061
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
6162
public:
6263
constexpr static const char* kJavaDescriptor =
@@ -352,19 +353,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
352353
return jresult;
353354
}
354355

356+
jint forward_ones() {
357+
auto&& load_result = module_->load_method("forward");
358+
auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method));
359+
auto&& result = module_->methods_["forward"].method->execute();
360+
return (jint)result;
361+
}
362+
355363
static void registerNatives() {
356364
registerHybrid({
357365
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
358366
makeNativeMethod("forward", ExecuTorchJni::forward),
359367
makeNativeMethod("execute", ExecuTorchJni::execute),
360368
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
369+
makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones),
361370
});
362371
}
363372
};
364-
365-
} // namespace executorch_jni
373+
} // namespace executorch::extension
366374

367375
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
368376
return facebook::jni::initialize(
369-
vm, [] { executorch_jni::ExecuTorchJni::registerNatives(); });
377+
vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); });
370378
}

extension/android/jni/jni_layer_constants.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <executorch/runtime/core/portable_type/scalar_type.h>
1212

13-
namespace executorch_jni {
13+
namespace executorch::extension {
1414

1515
constexpr static int kTensorDTypeUInt8 = 0;
1616
constexpr static int kTensorDTypeInt8 = 1;
@@ -93,4 +93,4 @@ const std::unordered_map<int, ScalarType> java_dtype_to_scalar_type = {
9393
{kTensorDTypeBits16, ScalarType::Bits16},
9494
};
9595

96-
} // namespace executorch_jni
96+
} // namespace executorch::extension

extension/android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ public static Module load(final String modelPath) {
7979
* @return return value from the 'forward' method.
8080
*/
8181
public EValue[] forward(EValue... inputs) {
82+
if (inputs.length == 0) {
83+
// forward default args (ones)
84+
mNativePeer.forwardOnes();
85+
// discard the return value
86+
return null;
87+
}
8288
return mNativePeer.forward(inputs);
8389
}
8490

extension/android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import java.util.Map;
1515

16+
/** Interface for the native peer object for entry points to the Module */
1617
class NativePeer {
1718
static {
1819
// Loads libexecutorch.so from jniLibs
@@ -29,16 +30,33 @@ private static native HybridData initHybrid(
2930
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, loadMode);
3031
}
3132

33+
/** Clean up the native resources associated with this instance */
3234
public void resetNative() {
3335
mHybridData.resetNative();
3436
}
3537

38+
/** Run a "forward" call with the given inputs */
3639
@DoNotStrip
3740
public native EValue[] forward(EValue... inputs);
3841

42+
/**
43+
* Run a "forward" call with the sample inputs (ones) to test a module
44+
*
45+
* @return the outputs of the forward call
46+
* @apiNote This is experimental and test-only API
47+
*/
48+
@DoNotStrip
49+
public native int forwardOnes();
50+
51+
/** Run an arbitrary method on the module */
3952
@DoNotStrip
4053
public native EValue[] execute(String methodName, EValue... inputs);
4154

55+
/**
56+
* Load a method on this module.
57+
*
58+
* @return the Error code if there was an error loading the method
59+
*/
4260
@DoNotStrip
4361
public native int loadMethod(String methodName);
4462
}

extension/module/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ class Module final {
358358
std::unique_ptr<::executorch::runtime::MemoryAllocator> temp_allocator_;
359359
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer_;
360360
std::unordered_map<std::string, MethodHolder> methods_;
361+
362+
friend class ExecuTorchJni;
361363
};
362364

363365
} // namespace extension

0 commit comments

Comments
 (0)