Commit c9c1afb

backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels
1 parent f010b77 commit c9c1afb

10 files changed: +261 additions, -90 deletions


common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -1993,6 +1993,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             common_log_set_timestamps(common_log_main(), true);
         }
     ).set_env("LLAMA_LOG_TIMESTAMPS"));
+    add_opt(common_arg(
+        {"-rtrp", "--runtime-repack"},
+        string_format("Allow runtime requantization and repacking of Q4_0 to enable optimized GEMM and GEMV kernels (default: %d)", params.runtime_repack),
+        [](common_params & params) {
+            params.runtime_repack = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));

     return ctx_arg;
 }

common/common.cpp

Lines changed: 2 additions & 1 deletion
@@ -996,7 +996,7 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
     mparams.main_gpu      = params.main_gpu;
     mparams.split_mode    = params.split_mode;
     mparams.tensor_split  = params.tensor_split;
-    mparams.use_mmap      = params.use_mmap;
+    mparams.use_mmap      = params.use_mmap && !params.runtime_repack;
     mparams.use_mlock     = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
     if (params.kv_overrides.empty()) {
@@ -1066,6 +1066,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.offload_kqv    = !params.no_kv_offload;
     cparams.flash_attn     = params.flash_attn;
     cparams.no_perf        = params.no_perf;
+    cparams.runtime_repack = params.runtime_repack;

     if (params.reranking) {
         cparams.embeddings = true;

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -265,6 +265,8 @@ struct common_params {
     bool warmup         = true;  // warmup run
     bool check_tensors  = false; // validate tensor data

+    bool runtime_repack = false; // runtime repack weight for optimized kernels
+
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V


examples/llama-bench/llama-bench.cpp

Lines changed: 112 additions & 84 deletions
Large diffs are not rendered by default.

ggml/include/ggml-backend.h

Lines changed: 1 addition & 0 deletions
@@ -310,6 +310,7 @@ extern "C" {
     GGML_API void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
     GGML_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
     GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+    GGML_API void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack);

     // Create a backend buffer from an existing pointer
     GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);

ggml/src/ggml-aarch64.c

Lines changed: 99 additions & 0 deletions
@@ -3207,3 +3207,102 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
 }
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(t->ne[0] % 8 == 0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    // Do in-place transformation. Allocate scratch buffer
+    size_t size = sizeof(block_q4_0x4) * t->ne[0] / QK4_0;
+    if (size > *psize) {
+        uint8_t *new_mem = realloc(*pmem, size);
+        if (!new_mem) {
+            return -1;
+        }
+        *pmem = new_mem;
+        *psize = size;
+    }
+    block_q4_0x4 *dst = (block_q4_0x4*) *pmem;
+    block_q4_0 *src = (block_q4_0*) t->data;
+    block_q4_0 dst_tmp[4];
+    int n = t->ne[0];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+    for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
+        int cnt = 0;
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++ ) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            dst[cnt++] = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
+        }
+        memcpy(src, dst, size);
+        src += cnt * 4;
+    }
+    return 0;
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(t->ne[0] % 8 == 0);
+    GGML_ASSERT(interleave_block == 8);
+
+    // Do in-place transformation. Allocate scratch buffer
+    size_t size = sizeof(block_q4_0x8) * t->ne[0] / QK4_0;
+    if (size > *psize) {
+        uint8_t *new_mem = realloc(*pmem, size);
+        if (!new_mem) {
+            return -1;
+        }
+        *pmem = new_mem;
+        *psize = size;
+    }
+    block_q4_0x8 *dst = (block_q4_0x8*) *pmem;
+    block_q4_0 *src = (block_q4_0*) t->data;
+    block_q4_0 dst_tmp[8];
+    int n = t->ne[0];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+    for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
+        int cnt = 0;
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++ ) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            dst[cnt++] = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
+        }
+        memcpy(src, dst, size);
+        src += cnt * 4;
+    }
+    return 0;
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize) {
+    UNUSED(cur);
+    UNUSED(pmem);
+    UNUSED(psize);
+
+#if defined(__ARM_ARCH)
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
+            if (repack_q4_0_to_q4_0_8_bl(cur, 8, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_8_8;
+            }
+        }
+        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (repack_q4_0_to_q4_0_4_bl(cur, 8, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_4_8;
+            }
+        }
+        else if (ggml_cpu_has_neon()) {
+            if (repack_q4_0_to_q4_0_4_bl(cur, 4, pmem, psize) == 0) {
+                cur->type = GGML_TYPE_Q4_0_4_4;
+            }
+        }
+    }
+#endif
+}
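
The two repack helpers above gather blocks column-wise from 4 (or 8) consecutive rows before fusing them into one interleaved block. Below is a small, self-contained illustration of that gather order; it is not part of the commit, plain ints stand in for block_q4_0 blocks, and the toy shape (4 rows, 3 blocks per row) is an assumption:

// Toy illustration (not from the commit): the indexing mirrors
// src[x + i * nblocks] in repack_q4_0_to_q4_0_4_bl above.
#include <stdio.h>

int main(void) {
    enum { NROWS = 4, NBLOCKS = 3 };   // hypothetical shape: 4 rows x 3 blocks per row
    int src[NROWS * NBLOCKS];
    for (int i = 0; i < NROWS * NBLOCKS; i++) {
        src[i] = i;                    // row-major block ids, as stored in a Q4_0 tensor
    }

    // For each block column x, gather one block from each of the 4 interleaved rows;
    // the real code then packs the 4 gathered blocks into a single block_q4_0x4.
    for (int x = 0; x < NBLOCKS; x++) {
        printf("block_q4_0x4 #%d <- src blocks", x);
        for (int r = 0; r < NROWS; r++) {
            printf(" %d", src[x + r * NBLOCKS]);
        }
        printf("\n");
    }
    return 0;
}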

ggml/src/ggml-aarch64.h

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

+void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize);
+
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-backend.cpp

Lines changed: 26 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-alloc.h"
 #include "ggml-impl.h"
+#include "ggml-aarch64.h"

 #include <assert.h>
 #include <limits.h>
@@ -882,6 +883,10 @@ struct ggml_backend_cpu_context {
     uint8_t * work_data;
     size_t work_size;

+    bool runtime_repack;
+    uint8_t * scratch_memory;
+    size_t scratch_size;
+
     ggml_abort_callback abort_callback;
     void * abort_callback_data;
 };
@@ -895,6 +900,7 @@ static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
 static void ggml_backend_cpu_free(ggml_backend_t backend) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
     delete[] cpu_ctx->work_data;
+    free(cpu_ctx->scratch_memory); // free the scratch memory allocated by C module
     delete cpu_ctx;
     delete backend;
 }
@@ -952,6 +958,16 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe
 static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

+    if (cpu_ctx->runtime_repack) {
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            struct ggml_tensor * node = cgraph->nodes[i];
+            if (node->op == GGML_OP_MUL_MAT && node->src[0]->type == GGML_TYPE_Q4_0) {
+                // Prepare for optimized kernels if applicable.
+                ggml_prepare_optimal_kernel(node->src[0], &cpu_ctx->scratch_memory, &cpu_ctx->scratch_size);
+            }
+        }
+    }
+
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);

     if (cpu_ctx->work_size < cplan.work_size) {
@@ -1008,6 +1024,9 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->work_size           = 0;
     ctx->abort_callback      = NULL;
     ctx->abort_callback_data = NULL;
+    ctx->runtime_repack      = false;
+    ctx->scratch_memory      = NULL;
+    ctx->scratch_size        = 0;

     ggml_backend_t cpu_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_cpu_guid(),
@@ -1055,6 +1074,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
     ctx->abort_callback_data = abort_callback_data;
 }

+void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->runtime_repack = runtime_repack;
+}
+
 ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
     GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
     return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
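
For context, a minimal sketch (not from this commit) of how the new setter is driven when a CPU backend is created directly through the ggml-backend API. Only ggml_backend_cpu_set_runtime_repack is new (declared in ggml-backend.h above); the other calls are pre-existing API, and the helper name is hypothetical:

// Sketch only: create a CPU backend whose next graph compute will repack
// Q4_0 mul_mat weights via the hook added to ggml_backend_cpu_graph_compute().
#include "ggml-backend.h"

static ggml_backend_t make_repacking_cpu_backend(int n_threads) {
    ggml_backend_t backend = ggml_backend_cpu_init();
    if (backend == NULL) {
        return NULL;
    }
    ggml_backend_cpu_set_n_threads(backend, n_threads);
    ggml_backend_cpu_set_runtime_repack(backend, true); // enable the online repack path
    return backend;                                     // caller releases it with ggml_backend_free()
}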

include/llama.h

Lines changed: 6 additions & 5 deletions
@@ -341,11 +341,12 @@ extern "C" {

         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         // TODO: move at the end of the struct
-        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
+        bool logits_all;     // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        bool embeddings;     // if true, extract embeddings (together with logits)
+        bool offload_kqv;    // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;     // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;        // whether to measure performance timings
+        bool runtime_repack; // runtime repack weight for optimized kernels

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
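
A hedged usage sketch (not part of the diff) of how a C API caller could opt in, mirroring what common/common.cpp does above: mmap is disabled so the Q4_0 weights sit in writable memory the backend can rewrite in place, and the new runtime_repack flag is set on the context params. Only the runtime_repack field is introduced by this commit; the helper itself is hypothetical:

// Sketch under the assumptions above; every API call used here predates the commit.
#include <stddef.h>
#include "llama.h"

static struct llama_context * context_with_runtime_repack(const char * model_path) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.use_mmap = false;                 // in-place repack needs writable, private weight buffers

    struct llama_model * model = llama_load_model_from_file(model_path, mparams);
    if (model == NULL) {
        return NULL;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.runtime_repack = true;            // ask the CPU backend to repack Q4_0 weights online

    return llama_new_context_with_model(model, cparams);
}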

src/llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -2574,6 +2574,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool runtime_repack;

     enum llama_pooling_type pooling_type;

@@ -17107,6 +17108,7 @@ static void llama_graph_compute(
             ggml_threadpool * threadpool) {
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
+        ggml_backend_cpu_set_runtime_repack(lctx.backend_cpu, lctx.cparams.runtime_repack);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }

@@ -19034,6 +19036,7 @@ struct llama_context_params llama_context_default_params() {
         /*.offload_kqv         =*/ true,
         /*.flash_attn          =*/ false,
         /*.no_perf             =*/ true,
+        /*.runtime_repack      =*/ false,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -19292,6 +19295,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.flash_attn     = params.flash_attn;
     cparams.no_perf        = params.no_perf;
     cparams.pooling_type   = params.pooling_type;
+    cparams.runtime_repack = params.runtime_repack;

     cparams.n_ctx          = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;

0 commit comments
