Skip to content

Address ARM's big.LITTLE arch by checking cpu info. #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {

suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
require(ptr != 0L)
WhisperLib.fullTranscribe(ptr, data)
val numThreads = WhisperCpuConfig.preferredThreadCount
Log.d(LOG_TAG, "Selecting $numThreads threads")
WhisperLib.fullTranscribe(ptr, numThreads, data)
val textCount = WhisperLib.getTextSegmentCount(ptr)
return@withContext buildString {
for (i in 0 until textCount) {
Expand Down Expand Up @@ -126,7 +128,7 @@ private class WhisperLib {
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.whispercppdemo.whisper

import android.util.Log
import java.io.BufferedReader
import java.io.FileReader

object WhisperCpuConfig {
val preferredThreadCount: Int
// Always use at least 2 threads:
get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
}

private class CpuInfo(private val lines: List<String>) {
private fun getHighPerfCpuCount(): Int = try {
getHighPerfCpuCountByFrequencies()
} catch (e: Exception) {
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
getHighPerfCpuCountByVariant()
}

private fun getHighPerfCpuCountByFrequencies(): Int =
getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
.also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
.countDroppingMin()

private fun getHighPerfCpuCountByVariant(): Int =
getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
.also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
.countKeepingMin()

private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()

private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
.asSequence()
.filter { it.startsWith(property) }
.map { mapper(it.substringAfter(':').trim()) }
.sorted()
.toList()


private fun List<Int>.countDroppingMin(): Int {
val min = min()
return count { it > min }
}

private fun List<Int>.countKeepingMin(): Int {
val min = min()
return count { it == min }
}

companion object {
private const val LOG_TAG = "WhisperCpuConfig"

fun getHighPerfCpuCount(): Int = try {
readCpuInfo().getHighPerfCpuCount()
} catch (e: Exception) {
Log.d(LOG_TAG, "Couldn't read CPU info", e)
// Our best guess -- just return the # of CPUs minus 4.
(Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
}

private fun readCpuInfo() = CpuInfo(
BufferedReader(FileReader("/proc/cpuinfo"))
.useLines { it.toList() }
)

private fun getMaxCpuFrequency(cpuIndex: Int): Int {
val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
return maxFreq.toInt()
}
}
}
8 changes: 2 additions & 6 deletions examples/whisper.android/app/src/main/jni/whisper/jni.c
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(

JNIEXPORT void JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
UNUSED(thiz);
struct whisper_context *context = (struct whisper_context *) context_ptr;
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);

// Leave 2 processors free (i.e. the high-efficiency cores).
int max_threads = max(1, min(8, get_nprocs() - 2));
LOGI("Selecting %d threads", max_threads);

// The below adapted from the Objective-C iOS sample
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
params.print_realtime = true;
Expand All @@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
params.print_special = false;
params.translate = false;
params.language = "en";
params.n_threads = max_threads;
params.n_threads = num_threads;
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;
Expand Down