Skip to content

Commit 7361896

Browse files
committed
Now it should fix
1 parent 5bda99f commit 7361896

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

extension/android/executorch_android/android_test_setup.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ prepare_tinyllama() {
3636
prepare_vision() {
3737
pushd "${BASEDIR}/../../../"
3838
python3 -m examples.xnnpack.aot_compiler --model_name "mv2" --delegate
39-
cp mv2*.pte "${BASEDIR}/src/androidTest/resources/"
39+
python3 -m examples.xnnpack.aot_compiler --model_name "mv3" --delegate
40+
python3 -m examples.xnnpack.aot_compiler --model_name "resnet50" --quantize --delegate
41+
cp mv2*.pte mv3*.pte resnet50*.pte "${BASEDIR}/src/androidTest/resources/"
4042
popd
4143
}
4244

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static org.junit.Assert.assertNotEquals;
1515
import static org.junit.Assert.fail;
1616

17+
import android.graphics.Bitmap;
18+
import android.graphics.BitmapFactory;
1719
import android.os.Environment;
1820
import androidx.test.rule.GrantPermissionRule;
1921
import android.Manifest;
@@ -45,18 +47,60 @@ private static String getTestFilePath(String fileName) {
4547
@Rule
4648
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
4749

48-
@Test
49-
public void testMv2Fp32() throws IOException, URISyntaxException{
50-
String filePath = "/mv2_xnnpack_fp32.pte";
50+
static int argmax(float[] array) {
51+
if (array.length == 0) {
52+
throw new IllegalArgumentException("Array cannot be empty");
53+
}
54+
int maxIndex = 0;
55+
float maxValue = array[0];
56+
for (int i = 1; i < array.length; i++) {
57+
if (array[i] > maxValue) {
58+
maxValue = array[i];
59+
maxIndex = i;
60+
}
61+
}
62+
return maxIndex;
63+
}
64+
65+
public void testClassification(String filePath) throws IOException, URISyntaxException {
5166
File pteFile = new File(getTestFilePath(filePath));
5267
InputStream inputStream = getClass().getResourceAsStream(filePath);
5368
FileUtils.copyInputStreamToFile(inputStream, pteFile);
5469
inputStream.close();
5570

71+
InputStream imgInputStream = getClass().getResourceAsStream("/banana.jpeg");
72+
Bitmap bitmap = BitmapFactory.decodeStream(imgInputStream);
73+
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
74+
imgInputStream.close();
75+
76+
Tensor inputTensor =
77+
TensorImageUtils.bitmapToFloat32Tensor(
78+
bitmap,
79+
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
80+
TensorImageUtils.TORCHVISION_NORM_STD_RGB);
81+
5682
Module module = Module.load(getTestFilePath(filePath));
5783

58-
EValue[] results = module.forward();
84+
EValue[] results = module.forward(EValue.from(inputTensor));
5985
assertTrue(results[0].isTensor());
86+
float[] scores = results[0].toTensor().getDataAsFloatArray();
87+
88+
int bananaClass = 954; // From ImageNet 1K
89+
assertEquals(bananaClass, argmax(scores));
90+
}
91+
92+
@Test
93+
public void testMv2Fp32() throws IOException, URISyntaxException {
94+
testClassification("/mv2_xnnpack_fp32.pte");
6095
}
6196

97+
@Test
98+
public void testMv3Fp32() throws IOException, URISyntaxException {
99+
testClassification("/mv3_xnnpack_fp32.pte");
100+
}
101+
102+
@Test
103+
public void testResnet50() throws IOException, URISyntaxException {
104+
testClassification("/resnet50_xnnpack_q8.pte");
105+
}
62106
}

0 commit comments

Comments
 (0)