Skip to content

Android library update for benchmarking support #5000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build/build_android_llm_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ build_android_native_library() {
-DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
Expand Down
11 changes: 9 additions & 2 deletions extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ find_package(executorch CONFIG REQUIRED)
target_link_options_shared_lib(executorch)

set(link_libraries)
list(APPEND link_libraries extension_data_loader extension_module extension_threadpool executorch
fbjni
list(
APPEND
link_libraries
executorch
extension_data_loader
extension_module
extension_runner_util
extension_threadpool
fbjni
)

if(TARGET optimized_native_cpu_ops_lib)
Expand Down
16 changes: 12 additions & 4 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "jni_layer_constants.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/extension/runner_util/managed_tensor.h>
#include <executorch/runtime/core/portable_type/tensor_impl.h>
#include <executorch/runtime/platform/log.h>
Expand Down Expand Up @@ -56,7 +57,7 @@ void et_pal_emit_log_message(

using namespace torch::executor;

namespace executorch_jni {
namespace executorch::extension {
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
public:
constexpr static const char* kJavaDescriptor =
Expand Down Expand Up @@ -352,19 +353,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
return jresult;
}

jint forward_ones() {
auto&& load_result = module_->load_method("forward");
auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method));
auto&& result = module_->methods_["forward"].method->execute();
return (jint)result;
}

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
makeNativeMethod("forward", ExecuTorchJni::forward),
makeNativeMethod("execute", ExecuTorchJni::execute),
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones),
});
}
};

} // namespace executorch_jni
} // namespace executorch::extension

JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(
vm, [] { executorch_jni::ExecuTorchJni::registerNatives(); });
vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); });
}
4 changes: 2 additions & 2 deletions extension/android/jni/jni_layer_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <executorch/runtime/core/portable_type/scalar_type.h>

namespace executorch_jni {
namespace executorch::extension {

constexpr static int kTensorDTypeUInt8 = 0;
constexpr static int kTensorDTypeInt8 = 1;
Expand Down Expand Up @@ -93,4 +93,4 @@ const std::unordered_map<int, ScalarType> java_dtype_to_scalar_type = {
{kTensorDTypeBits16, ScalarType::Bits16},
};

} // namespace executorch_jni
} // namespace executorch::extension
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ public static Module load(final String modelPath) {
* @return return value from the 'forward' method.
*/
public EValue[] forward(EValue... inputs) {
if (inputs.length == 0) {
// forward default args (ones)
mNativePeer.forwardOnes();
// discard the return value
return null;
}
return mNativePeer.forward(inputs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.facebook.soloader.nativeloader.NativeLoader;
import java.util.Map;

/** Interface for the native peer object for entry points to the Module */
class NativePeer {
static {
// Loads libexecutorch.so from jniLibs
Expand All @@ -29,16 +30,33 @@ private static native HybridData initHybrid(
mHybridData = initHybrid(moduleAbsolutePath, extraFiles, loadMode);
}

/** Clean up the native resources associated with this instance */
public void resetNative() {
mHybridData.resetNative();
}

/** Run a "forward" call with the given inputs */
@DoNotStrip
public native EValue[] forward(EValue... inputs);

/**
* Run a "forward" call with the sample inputs (ones) to test a module
*
* @return the outputs of the forward call
* @apiNote This is experimental and test-only API
*/
@DoNotStrip
public native int forwardOnes();

/** Run an arbitrary method on the module */
@DoNotStrip
public native EValue[] execute(String methodName, EValue... inputs);

/**
* Load a method on this module.
*
* @return the Error code if there was an error loading the method
*/
@DoNotStrip
public native int loadMethod(String methodName);
}
2 changes: 2 additions & 0 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ class Module final {
std::unique_ptr<::executorch::runtime::MemoryAllocator> temp_allocator_;
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer_;
std::unordered_map<std::string, MethodHolder> methods_;

friend class ExecuTorchJni;
};

} // namespace extension
Expand Down
Loading