Skip to content

Commit 2660287

Browse files
authored
added instrumentation test for LlamaModule (#6759)
Added instrumentation test for LlamaModule. Modified setup.sh to include building stories110M model and moves it into src/androidTest/resources Added test cases for LlamaModule by generating a sequence length of 32, and verifying the length. Also verifies that stop() works by checking output length is less than input sequence length [ghstack-poisoned]
1 parent b23c9e6 commit 2660287

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

extension/android_test/setup.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@ build_native_library() {
2121
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
2222
-DANDROID_ABI="${ANDROID_ABI}" \
2323
-DEXECUTORCH_BUILD_XNNPACK=ON \
24+
-DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
2425
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
2526
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
2627
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
2728
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
29+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
30+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
2831
-B"${CMAKE_OUT}"
2932

3033
cmake --build "${CMAKE_OUT}" -j16 --target install
@@ -33,6 +36,7 @@ build_native_library() {
3336
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \
3437
-DANDROID_ABI="${ANDROID_ABI}" \
3538
-DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \
39+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
3640
-DEXECUTORCH_BUILD_LLAMA_JNI=ON \
3741
-B"${CMAKE_OUT}"/extension/android
3842

@@ -48,6 +52,8 @@ build_jar
4852
build_native_library "arm64-v8a"
4953
build_native_library "x86_64"
5054
build_aar
55+
source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_DIR}
5156
popd
5257
mkdir -p "$BASEDIR"/src/libs
5358
cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar
59+
unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.fail;
16+
17+
import android.os.Environment;
18+
import androidx.test.rule.GrantPermissionRule;
19+
import android.Manifest;
20+
import android.content.Context;
21+
import org.junit.Test;
22+
import org.junit.Before;
23+
import org.junit.Rule;
24+
import org.junit.runner.RunWith;
25+
import java.io.InputStream;
26+
import java.net.URI;
27+
import java.net.URISyntaxException;
28+
import java.util.List;
29+
import java.util.ArrayList;
30+
import java.io.IOException;
31+
import java.io.File;
32+
import java.io.FileOutputStream;
33+
import org.junit.runners.JUnit4;
34+
import org.apache.commons.io.FileUtils;
35+
import androidx.test.ext.junit.runners.AndroidJUnit4;
36+
import androidx.test.InstrumentationRegistry;
37+
import org.pytorch.executorch.LlamaModule;
38+
import org.pytorch.executorch.LlamaCallback;
39+
import org.pytorch.executorch.Module;
40+
import org.pytorch.executorch.EValue;
41+
import org.pytorch.executorch.Tensor;
42+
43+
/** Unit tests for {@link LlamaModule}. */
44+
@RunWith(AndroidJUnit4.class)
45+
public class LlamaModuleInstrumentationTest implements LlamaCallback {
46+
private static String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte";
47+
private static String TOKENIZER_FILE_NAME = "/tokenizer.bin";
48+
private static String TEST_PROMPT = "Hello";
49+
private static int OK = 0x00;
50+
private static int SEQ_LEN = 32;
51+
52+
private final List<String> results = new ArrayList<>();
53+
private final List<Float> tokensPerSecond = new ArrayList<>();
54+
private LlamaModule mModule;
55+
56+
private static String getTestFilePath(String fileName) {
57+
return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName;
58+
}
59+
60+
@Before
61+
public void setUp() throws IOException {
62+
// copy zipped test resources to local device
63+
File addPteFile = new File(getTestFilePath(TEST_FILE_NAME));
64+
InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME);
65+
FileUtils.copyInputStreamToFile(inputStream, addPteFile);
66+
inputStream.close();
67+
68+
File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME));
69+
inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME);
70+
FileUtils.copyInputStreamToFile(inputStream, tokenizerFile);
71+
inputStream.close();
72+
73+
mModule = new LlamaModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f);
74+
}
75+
76+
@Rule
77+
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
78+
79+
@Test
80+
public void testGenerate() throws IOException, URISyntaxException{
81+
int loadResult = mModule.load();
82+
// Check that the model can be load successfully
83+
assertEquals(OK, loadResult);
84+
85+
mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this);
86+
assertEquals(results.size(), SEQ_LEN);
87+
assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0);
88+
}
89+
90+
@Test
91+
public void testGenerateAndStop() throws IOException, URISyntaxException{
92+
int seqLen = 32;
93+
mModule.generate(TEST_PROMPT, SEQ_LEN, new LlamaCallback() {
94+
@Override
95+
public void onResult(String result) {
96+
LlamaModuleInstrumentationTest.this.onResult(result);
97+
mModule.stop();
98+
}
99+
100+
@Override
101+
public void onStats(float tps) {
102+
LlamaModuleInstrumentationTest.this.onStats(tps);
103+
}
104+
});
105+
106+
int stoppedResultSize = results.size();
107+
assertTrue(stoppedResultSize < SEQ_LEN);
108+
}
109+
110+
@Override
111+
public void onResult(String result) {
112+
results.add(result);
113+
}
114+
115+
@Override
116+
public void onStats(float tps) {
117+
tokensPerSecond.add(tps);
118+
}
119+
}

0 commit comments

Comments
 (0)