Skip to content

Commit ee64dc6

Browse files
phaitingHaiting Pu
andauthored
Convert all Instrumentation and Module E2E test to kotlin (#10995)
### Summary This change converts all Instrumentation and Module e2e test to kotlin. ### Test plan ./gradlew :executorch_android:connectedAndroidTest Resolved: #10454 --------- Co-authored-by: Haiting Pu <[email protected]>
1 parent 4e38f4a commit ee64dc6

File tree

8 files changed

+543
-573
lines changed

8 files changed

+543
-573
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java

Lines changed: 0 additions & 126 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.Manifest
11+
import androidx.test.InstrumentationRegistry
12+
import androidx.test.ext.junit.runners.AndroidJUnit4
13+
import androidx.test.rule.GrantPermissionRule
14+
import org.apache.commons.io.FileUtils
15+
import org.json.JSONException
16+
import org.json.JSONObject
17+
import org.junit.Assert
18+
import org.junit.Before
19+
import org.junit.Rule
20+
import org.junit.Test
21+
import org.junit.runner.RunWith
22+
import org.pytorch.executorch.extension.llm.LlmCallback
23+
import org.pytorch.executorch.extension.llm.LlmModule
24+
import java.io.File
25+
import java.io.IOException
26+
import java.net.URISyntaxException
27+
28+
/** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
29+
@RunWith(AndroidJUnit4::class)
30+
class LlmModuleInstrumentationTest : LlmCallback {
31+
private val results: MutableList<String> = ArrayList()
32+
private val tokensPerSecond: MutableList<Float> = ArrayList()
33+
private var llmModule: LlmModule? = null
34+
35+
@Before
36+
@Throws(IOException::class)
37+
fun setUp() {
38+
// copy zipped test resources to local device
39+
val addPteFile = File(getTestFilePath(TEST_FILE_NAME))
40+
var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME)
41+
FileUtils.copyInputStreamToFile(inputStream, addPteFile)
42+
inputStream.close()
43+
44+
val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME))
45+
inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)
46+
FileUtils.copyInputStreamToFile(inputStream, tokenizerFile)
47+
inputStream.close()
48+
49+
llmModule =
50+
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
51+
}
52+
53+
@get:Rule
54+
var runtimePermissionRule: GrantPermissionRule =
55+
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
56+
57+
@Test
58+
@Throws(IOException::class, URISyntaxException::class)
59+
fun testGenerate() {
60+
val loadResult = llmModule!!.load()
61+
// Check that the model can be load successfully
62+
Assert.assertEquals(OK.toLong(), loadResult.toLong())
63+
64+
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
65+
Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong())
66+
Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
67+
}
68+
69+
@Test
70+
@Throws(IOException::class, URISyntaxException::class)
71+
fun testGenerateAndStop() {
72+
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, object : LlmCallback {
73+
override fun onResult(result: String) {
74+
this@LlmModuleInstrumentationTest.onResult(result)
75+
llmModule!!.stop()
76+
}
77+
78+
override fun onStats(stats: String) {
79+
this@LlmModuleInstrumentationTest.onStats(stats)
80+
}
81+
})
82+
83+
val stoppedResultSize = results.size
84+
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
85+
}
86+
87+
override fun onResult(result: String) {
88+
results.add(result)
89+
}
90+
91+
override fun onStats(stats: String) {
92+
var tps = 0f
93+
try {
94+
val jsonObject = JSONObject(stats)
95+
val numGeneratedTokens = jsonObject.getInt("generated_tokens")
96+
val inferenceEndMs = jsonObject.getInt("inference_end_ms")
97+
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
98+
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
99+
tokensPerSecond.add(tps)
100+
} catch (_: JSONException) {
101+
}
102+
}
103+
104+
companion object {
105+
private const val TEST_FILE_NAME = "/stories.pte"
106+
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
107+
private const val TEST_PROMPT = "Hello"
108+
private const val OK = 0x00
109+
private const val SEQ_LEN = 32
110+
111+
private fun getTestFilePath(fileName: String): String {
112+
return InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir.toString() + fileName
113+
}
114+
}
115+
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 0 additions & 119 deletions
This file was deleted.

0 commit comments

Comments
 (0)