Skip to content

move junit tests to android_test #6761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion extension/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,5 @@ task makeJar(type: Jar) {
dependencies {
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
testImplementation 'junit:junit:4.13.2'
}
}
26 changes: 26 additions & 0 deletions extension/android_test/add_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
from executorch.exir import to_edge
from torch.export import export


# Start with a PyTorch model that adds two input tensors (matrices)
class Add(torch.nn.Module):
    """Minimal module computing the elementwise sum of its two tensor inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # Elementwise addition via the functional form; identical to `x + y`.
        return torch.add(x, y)


# 1. torch.export: capture the program with the ATen operator set.
exported_program = export(Add(), (torch.ones(1), torch.ones(1)))

# 2. to_edge: lower to the Edge dialect for on-device execution.
edge_program = to_edge(exported_program)

# 3. to_executorch: convert the graph to an ExecuTorch program.
et_program = edge_program.to_executorch()

# 4. Save the compiled .pte program (consumed by the Android instrumentation tests).
with open("add.pte", "wb") as f:
    f.write(et_program.buffer)
8 changes: 8 additions & 0 deletions extension/android_test/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ build_native_library() {
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
-DANDROID_ABI="${ANDROID_ABI}" \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-B"${CMAKE_OUT}"

cmake --build "${CMAKE_OUT}" -j16 --target install
Expand All @@ -33,6 +36,7 @@ build_native_library() {
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \
-DANDROID_ABI="${ANDROID_ABI}" \
-DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_LLAMA_JNI=ON \
-B"${CMAKE_OUT}"/extension/android

Expand All @@ -48,6 +52,10 @@ build_jar
# Build the native ExecuTorch libraries for both ABIs exercised by the tests.
build_native_library "arm64-v8a"
build_native_library "x86_64"
# Package the Android AAR consumed by the instrumentation test project.
build_aar
# Export the stories110M llama test model; presumably this produces the
# model.zip unpacked below — TODO confirm against test_llama.sh.
source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_DIR}
popd
# Stage the built AAR where the Gradle project expects its local libs.
mkdir -p "$BASEDIR"/src/libs
cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar
# Generate the simple add.pte model and place it in the androidTest resources.
python add_model.py
mv "add.pte" "$BASEDIR"/src/androidTest/resources/add.pte
# Unpack the llama model/tokenizer artifacts into the androidTest resources.
unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.example.executorch;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.fail;

import android.os.Environment;
import androidx.test.rule.GrantPermissionRule;
import android.Manifest;
import android.content.Context;
import org.junit.Test;
import org.junit.Before;
import org.junit.Rule;
import org.junit.runner.RunWith;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.ArrayList;
import java.io.IOException;
import java.io.File;
import java.io.FileOutputStream;
import org.junit.runners.JUnit4;
import org.apache.commons.io.FileUtils;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.InstrumentationRegistry;
import org.pytorch.executorch.LlamaModule;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.Module;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Tensor;

/** Instrumentation tests for {@link LlamaModule}. */
@RunWith(AndroidJUnit4.class)
public class LlamaModuleInstrumentationTest implements LlamaCallback {
  // Resource paths inside the test APK; copied to the external cache dir in setUp().
  private static final String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte";
  private static final String TOKENIZER_FILE_NAME = "/tokenizer.bin";
  private static final String TEST_PROMPT = "Hello";
  private static final int OK = 0x00;
  private static final int SEQ_LEN = 32;

  // Tokens and throughput samples collected via the LlamaCallback methods below.
  private final List<String> results = new ArrayList<>();
  private final List<Float> tokensPerSecond = new ArrayList<>();
  private LlamaModule mModule;

  /** Resolves a bundled resource name to an on-device path in the external cache dir. */
  private static String getTestFilePath(String fileName) {
    return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir()
        + fileName;
  }

  @Before
  public void setUp() throws IOException {
    // Copy the zipped test resources (model + tokenizer) to the local device
    // so LlamaModule can open them by filesystem path.
    File addPteFile = new File(getTestFilePath(TEST_FILE_NAME));
    InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME);
    FileUtils.copyInputStreamToFile(inputStream, addPteFile);
    inputStream.close();

    File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME));
    inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME);
    FileUtils.copyInputStreamToFile(inputStream, tokenizerFile);
    inputStream.close();

    // Temperature 0.0f for deterministic generation.
    mModule =
        new LlamaModule(
            getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f);
  }

  @Rule
  public GrantPermissionRule mRuntimePermissionRule =
      GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);

  @Test
  public void testGenerate() throws IOException, URISyntaxException {
    int loadResult = mModule.load();
    // Check that the model can be loaded successfully.
    assertEquals(OK, loadResult);

    mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this);
    // JUnit convention: expected value first, actual second.
    assertEquals(SEQ_LEN, results.size());
    assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0);
  }

  @Test
  public void testGenerateAndStop() throws IOException, URISyntaxException {
    // Stop after the first token; generation should terminate well before SEQ_LEN.
    mModule.generate(
        TEST_PROMPT,
        SEQ_LEN,
        new LlamaCallback() {
          @Override
          public void onResult(String result) {
            LlamaModuleInstrumentationTest.this.onResult(result);
            mModule.stop();
          }

          @Override
          public void onStats(float tps) {
            LlamaModuleInstrumentationTest.this.onStats(tps);
          }
        });

    int stoppedResultSize = results.size();
    assertTrue(stoppedResultSize < SEQ_LEN);
  }

  @Override
  public void onResult(String result) {
    results.add(result);
  }

  @Override
  public void onStats(float tps) {
    tokensPerSecond.add(tps);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void testOptionalTensorListValue() {
Optional.of(Tensor.fromBlob(data[1], shape[1])));
assertTrue(evalue.isOptionalTensorList());

assertTrue(evalue.toOptionalTensorList()[0].isEmpty());
assertTrue(!evalue.toOptionalTensorList()[0].isPresent());

assertTrue(evalue.toOptionalTensorList()[1].isPresent());
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0]));
Expand Down
Loading