Skip to content

Commit 5bda99f

Browse files
committed
Now it should fix
1 parent 88d6397 commit 5bda99f

File tree

6 files changed

+37
-76
lines changed

6 files changed

+37
-76
lines changed

.github/workflows/_android.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,6 @@ jobs:
139139
heap-size: 12288M
140140
force-avd-creation: false
141141
disable-animations: true
142-
emulator-options: -memory 65536 -no-snapshot-save -no-window -gpu swiftshader_indirect -noaudio -no-boot-anim -camera-back none
142+
emulator-options: -no-snapshot-save -no-window -gpu swiftshader_indirect -noaudio -no-boot-anim -camera-back none
143143
# This is to make sure that the job doesn't fail flakily
144144
emulator-boot-timeout: 900

extension/android/executorch_android/android_test_setup.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ prepare_tinyllama() {
2525
# Create params.json file
2626
touch params.json
2727
echo '{"dim": 288, "multiple_of": 32, "n_heads": 6, "n_layers": 6, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
28-
python -m examples.models.llama.export_llama -c stories15M.pt -p params.json -d fp32 -n stories15m_h.pte -kv
28+
python -m examples.models.llama.export_llama -c stories15M.pt -p params.json -d fp16 -n stories15m_h.pte -kv
2929
python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
3030

3131
cp stories15m_h.pte "${BASEDIR}/src/androidTest/resources/stories.pte"
@@ -36,9 +36,7 @@ prepare_tinyllama() {
3636
prepare_vision() {
3737
pushd "${BASEDIR}/../../../"
3838
python3 -m examples.xnnpack.aot_compiler --model_name "mv2" --delegate
39-
python3 -m examples.xnnpack.aot_compiler --model_name "mv3" --delegate
40-
python3 -m examples.xnnpack.aot_compiler --model_name "resnet50" --quantize --delegate
41-
cp mv2*.pte mv3*.pte resnet50*.pte "${BASEDIR}/src/androidTest/resources/"
39+
cp mv2*.pte "${BASEDIR}/src/androidTest/resources/"
4240
popd
4341
}
4442

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,28 @@ public void testGenerate() throws IOException, URISyntaxException{
7979
// Check that the model can be load successfully
8080
assertEquals(OK, loadResult);
8181

82+
mModule.generate(TEST_PROMPT, SEQ_LEN, LlmModuleInstrumentationTest.this);
83+
assertEquals(results.size(), SEQ_LEN);
84+
assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0);
85+
}
86+
87+
@Test
88+
public void testGenerateAndStop() throws IOException, URISyntaxException{
89+
mModule.generate(TEST_PROMPT, SEQ_LEN, new LlmCallback() {
90+
@Override
91+
public void onResult(String result) {
92+
LlmModuleInstrumentationTest.this.onResult(result);
93+
mModule.stop();
94+
}
95+
96+
@Override
97+
public void onStats(float tps) {
98+
LlmModuleInstrumentationTest.this.onStats(tps);
99+
}
100+
});
101+
102+
int stoppedResultSize = results.size();
103+
assertTrue(stoppedResultSize < SEQ_LEN);
82104
}
83105

84106
@Override

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import static org.junit.Assert.assertNotEquals;
1515
import static org.junit.Assert.fail;
1616

17-
import android.graphics.Bitmap;
18-
import android.graphics.BitmapFactory;
1917
import android.os.Environment;
2018
import androidx.test.rule.GrantPermissionRule;
2119
import android.Manifest;
@@ -47,72 +45,18 @@ private static String getTestFilePath(String fileName) {
4745
@Rule
4846
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
4947

50-
static int argmax(float[] array) {
51-
if (array.length == 0) {
52-
throw new IllegalArgumentException("Array cannot be empty");
53-
}
54-
int maxIndex = 0;
55-
float maxValue = array[0];
56-
for (int i = 1; i < array.length; i++) {
57-
if (array[i] > maxValue) {
58-
maxValue = array[i];
59-
maxIndex = i;
60-
}
61-
}
62-
return maxIndex;
63-
}
64-
65-
public void testClassification(String filePath) throws IOException, URISyntaxException {
66-
File pteFile = new File(getTestFilePath(filePath));
67-
InputStream inputStream = getClass().getResourceAsStream(filePath);
68-
FileUtils.copyInputStreamToFile(inputStream, pteFile);
69-
inputStream.close();
70-
71-
InputStream imgInputStream = getClass().getResourceAsStream("/banana.jpeg");
72-
Bitmap bitmap = BitmapFactory.decodeStream(imgInputStream);
73-
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
74-
imgInputStream.close();
75-
76-
Tensor inputTensor =
77-
TensorImageUtils.bitmapToFloat32Tensor(
78-
bitmap,
79-
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
80-
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
81-
82-
Module module = Module.load(getTestFilePath(filePath));
83-
84-
EValue[] results = module.forward(EValue.from(inputTensor));
85-
assertTrue(results[0].isTensor());
86-
float[] scores = results[0].toTensor().getDataAsFloatArray();
87-
88-
int bananaClass = 954; // From ImageNet 1K
89-
assertEquals(bananaClass, argmax(scores));
90-
}
91-
9248
@Test
93-
public void testStories() throws IOException, URISyntaxException{
94-
String filePath = "/stories.pte";
49+
public void testMv2Fp32() throws IOException, URISyntaxException{
50+
String filePath = "/mv2_xnnpack_fp32.pte";
9551
File pteFile = new File(getTestFilePath(filePath));
9652
InputStream inputStream = getClass().getResourceAsStream(filePath);
9753
FileUtils.copyInputStreamToFile(inputStream, pteFile);
9854
inputStream.close();
9955

10056
Module module = Module.load(getTestFilePath(filePath));
101-
module.loadMethod("forward");
102-
module.forward();
103-
}
10457

105-
public void testMv2Fp32() throws IOException, URISyntaxException {
106-
testClassification("/mv2_xnnpack_fp32.pte");
107-
}
108-
109-
@Test
110-
public void testMv3Fp32() throws IOException, URISyntaxException {
111-
testClassification("/mv3_xnnpack_fp32.pte");
58+
EValue[] results = module.forward();
59+
assertTrue(results[0].isTensor());
11260
}
11361

114-
@Test
115-
public void testResnet50() throws IOException, URISyntaxException {
116-
testClassification("/resnet50_xnnpack_q8.pte");
117-
}
11862
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ private static Tensor nativeNewTensor(
675675
} else if (DType.INT8.jniCode == dtype) {
676676
tensor = new Tensor_int8(data, shape);
677677
} else {
678-
throw new IllegalArgumentException("Unknown Tensor dtype: " + dtype);
678+
throw new IllegalArgumentException("Unknown Tensor dtype");
679679
}
680680
tensor.mHybridData = hybridData;
681681
return tensor;

scripts/run_android_emulator.sh

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,29 @@
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
7+
78
set -ex
8-
free -h
9+
910
# This script is originally adopted from https://github.com/pytorch/pytorch/blob/main/android/run_tests.sh
1011
ADB_PATH=$ANDROID_HOME/platform-tools/adb
12+
1113
echo "Waiting for emulator boot to complete"
1214
# shellcheck disable=SC2016
1315
$ADB_PATH wait-for-device shell 'while [[ -z $(getprop sys.boot_completed) ]]; do sleep 5; done;'
16+
1417
# The device will be created by ReactiveCircus/android-emulator-runner GHA
1518
echo "List all running emulators"
1619
$ADB_PATH devices
17-
adb shell "free -h"
20+
1821
adb uninstall org.pytorch.executorch.test || true
1922
adb install -t android-test-debug-androidTest.apk
2023

21-
for i in {1..40}; do
22-
adb shell 'free -h'
23-
sleep 1
24-
done &
25-
2624
adb logcat -c
27-
adb shell am instrument -w -r \
25+
adb shell am instrument -w -r -e \
26+
class org.pytorch.executorch.ModuleInstrumentationTest,org.pytorch.executorch.ModuleE2ETest \
2827
org.pytorch.executorch.test/androidx.test.runner.AndroidJUnitRunner >result.txt 2>&1
2928
adb logcat -d > logcat.txt
30-
adb shell dumpsys meminfo
3129
cat logcat.txt
32-
cat result.txt
3330
grep -q FAILURES result.txt && cat result.txt
3431
grep -q FAILURES result.txt && exit -1
3532
exit 0

0 commit comments

Comments
 (0)