Skip to content

Commit abfc518

Browse files
committed
Sync bench code
1 parent 943bba2 commit abfc518

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ import kotlinx.coroutines.flow.catch
1010
import kotlinx.coroutines.launch
1111

1212
class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
13+
companion object {
14+
@JvmStatic
15+
private val NanosPerSecond = 1_000_000_000.0
16+
}
17+
1318
private val tag: String? = this::class.simpleName
1419

1520
var messages by mutableStateOf(listOf("Initializing..."))
@@ -39,9 +44,9 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
3944
messages += ""
4045

4146
viewModelScope.launch {
42-
llm.send(if (text.last() == '\n') text else text + "\n")
47+
llm.send(text)
4348
.catch {
44-
Log.e(tag, "send() flow failed", it)
49+
Log.e(tag, "send() failed", it)
4550
messages += it.message!!
4651
}
4752
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
@@ -51,8 +56,23 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
5156
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
5257
viewModelScope.launch {
5358
try {
54-
llm.bench(pp, tg, pl, nr)
59+
val start = System.nanoTime()
60+
val warmupResult = llm.bench(pp, tg, pl, nr)
61+
val end = System.nanoTime()
62+
63+
messages += warmupResult
64+
65+
val warmup = (end - start).toDouble() / NanosPerSecond
66+
messages += "Warm up time: $warmup seconds, please wait..."
67+
68+
if (warmup > 5.0) {
69+
messages += "Warm up took too long, aborting benchmark"
70+
return@launch
71+
}
72+
73+
messages += llm.bench(512, 128, 1, 3)
5574
} catch (exc: IllegalStateException) {
75+
Log.e(tag, "bench() failed", exc)
5676
messages += exc.message!!
5777
}
5878
}
@@ -64,6 +84,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
6484
llm.load(pathToModel)
6585
messages += "Loaded $pathToModel"
6686
} catch (exc: IllegalStateException) {
87+
Log.e(tag, "load() failed", exc)
6788
messages += exc.message!!
6889
}
6990
}

0 commit comments

Comments
 (0)