|
14 | 14 | import static org.junit.Assert.assertNotEquals;
|
15 | 15 | import static org.junit.Assert.fail;
|
16 | 16 |
|
| 17 | +import android.graphics.Bitmap; |
| 18 | +import android.graphics.BitmapFactory; |
17 | 19 | import android.os.Environment;
|
18 | 20 | import androidx.test.rule.GrantPermissionRule;
|
19 | 21 | import android.Manifest;
|
@@ -45,18 +47,60 @@ private static String getTestFilePath(String fileName) {
|
45 | 47 | @Rule
|
46 | 48 | public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
|
47 | 49 |
|
48 |
| - @Test |
49 |
| - public void testMv2Fp32() throws IOException, URISyntaxException{ |
50 |
| - String filePath = "/mv2_xnnpack_fp32.pte"; |
| 50 | + static int argmax(float[] array) { |
| 51 | + if (array.length == 0) { |
| 52 | + throw new IllegalArgumentException("Array cannot be empty"); |
| 53 | + } |
| 54 | + int maxIndex = 0; |
| 55 | + float maxValue = array[0]; |
| 56 | + for (int i = 1; i < array.length; i++) { |
| 57 | + if (array[i] > maxValue) { |
| 58 | + maxValue = array[i]; |
| 59 | + maxIndex = i; |
| 60 | + } |
| 61 | + } |
| 62 | + return maxIndex; |
| 63 | + } |
| 64 | + |
| 65 | + public void testClassification(String filePath) throws IOException, URISyntaxException { |
51 | 66 | File pteFile = new File(getTestFilePath(filePath));
|
52 | 67 | InputStream inputStream = getClass().getResourceAsStream(filePath);
|
53 | 68 | FileUtils.copyInputStreamToFile(inputStream, pteFile);
|
54 | 69 | inputStream.close();
|
55 | 70 |
|
| 71 | + InputStream imgInputStream = getClass().getResourceAsStream("/banana.jpeg"); |
| 72 | + Bitmap bitmap = BitmapFactory.decodeStream(imgInputStream); |
| 73 | + bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true); |
| 74 | + imgInputStream.close(); |
| 75 | + |
| 76 | + Tensor inputTensor = |
| 77 | + TensorImageUtils.bitmapToFloat32Tensor( |
| 78 | + bitmap, |
| 79 | + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, |
| 80 | + TensorImageUtils.TORCHVISION_NORM_STD_RGB); |
| 81 | + |
56 | 82 | Module module = Module.load(getTestFilePath(filePath));
|
57 | 83 |
|
58 |
| - EValue[] results = module.forward(); |
| 84 | + EValue[] results = module.forward(EValue.from(inputTensor)); |
59 | 85 | assertTrue(results[0].isTensor());
|
| 86 | + float[] scores = results[0].toTensor().getDataAsFloatArray(); |
| 87 | + |
| 88 | + int bananaClass = 954; // From ImageNet 1K |
| 89 | + assertEquals(bananaClass, argmax(scores)); |
| 90 | + } |
| 91 | + |
| 92 | + @Test |
| 93 | + public void testMv2Fp32() throws IOException, URISyntaxException { |
| 94 | + testClassification("/mv2_xnnpack_fp32.pte"); |
60 | 95 | }
|
61 | 96 |
|
| 97 | + @Test |
| 98 | + public void testMv3Fp32() throws IOException, URISyntaxException { |
| 99 | + testClassification("/mv3_xnnpack_fp32.pte"); |
| 100 | + } |
| 101 | + |
| 102 | + @Test |
| 103 | + public void testResnet50() throws IOException, URISyntaxException { |
| 104 | + testClassification("/resnet50_xnnpack_q8.pte"); |
| 105 | + } |
62 | 106 | }
|
0 commit comments