
eval-callback: Example how to use eval callback for debugging #6576

Merged
merged 25 commits into from Apr 11, 2024

Changes from 10 commits

Commits (25)
067e294
gguf-debug: Example how to use ggml callback for debugging
phymbert Apr 10, 2024
f63b722
gguf-debug: no mutex, verify type, fix stride.
phymbert Apr 10, 2024
8fe3be8
llama: cv eval: move cb eval field in common gpt_params
phymbert Apr 10, 2024
01dd5e9
ggml_debug: use common gpt_params to pass cb eval.
phymbert Apr 10, 2024
cda1d42
ggml_debug: ci: add tests
phymbert Apr 10, 2024
2d34bbe
ggml_debug: EOL in CMakeLists.txt
phymbert Apr 10, 2024
fe4b191
ggml_debug: Remove unused param n_batch, no batching here
phymbert Apr 10, 2024
08fa088
ggml_debug: fix trailing spaces
phymbert Apr 10, 2024
ca6f3ff
ggml_debug: fix trailing spaces
phymbert Apr 10, 2024
f3f0d18
common: fix cb_eval and user data not initialized
phymbert Apr 10, 2024
1a031d3
ci: build revert label
phymbert Apr 10, 2024
368272c
ggml_debug: add main test label
phymbert Apr 10, 2024
deadf29
Merge remote-tracking branch 'origin/master' into hp/ggml/debug
phymbert Apr 10, 2024
0b33928
doc: add a model: add a link to ggml-debug
phymbert Apr 10, 2024
a42ebbd
ggml-debug: add to make toolchain
phymbert Apr 10, 2024
f84473d
ggml-debug: tests add the main label
phymbert Apr 10, 2024
52a8e06
ggml-debug: ci add test curl label
phymbert Apr 10, 2024
831c97e
common: allow the warmup to be disabled in llama_init_from_gpt_params
phymbert Apr 10, 2024
3f8a93f
ci: add curl test
phymbert Apr 10, 2024
cfb820b
ggml-debug: better tensor type support
phymbert Apr 10, 2024
bb359cd
gitignore : ggml-debug
ggerganov Apr 11, 2024
8d7be2c
ggml-debug: printing also the sum of each tensor
phymbert Apr 11, 2024
28fd76f
ggml-debug: remove block size
phymbert Apr 11, 2024
ee588a5
eval-callback: renamed from ggml-debug
phymbert Apr 11, 2024
12731d2
eval-callback: fix make toolchain
phymbert Apr 11, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -223,7 +223,7 @@ jobs:
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
ctest --verbose --timeout 900

- name: Test llama2c conversion
id: llama2c_test
2 changes: 2 additions & 0 deletions common/common.cpp
@@ -1674,6 +1674,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.defrag_thold = params.defrag_thold;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
3 changes: 3 additions & 0 deletions common/common.h
@@ -80,6 +80,9 @@ struct gpt_params {
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
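Together with the common.cpp change above, these two fields let any example register a scheduler eval callback without building llama_context_params by hand. A minimal sketch of how an example is expected to use them (illustrative only, not part of the diff; my_eval_cb is a placeholder name):

```cpp
// Illustrative sketch, not part of this PR: how an example can hook the new
// gpt_params fields. The signature matches ggml_backend_sched_eval_callback.
#include "common.h"
#include "ggml.h"
#include <cstdio>

static bool my_eval_cb(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data; // unused in this sketch
    if (ask) {
        return true; // request the data of every tensor
    }
    printf("%s: op=%s\n", t->name, ggml_op_name(t->op));
    return true;     // continue the graph computation
}

// in main(), before llama_init_from_gpt_params(params):
//     params.cb_eval           = my_eval_cb;
//     params.cb_eval_user_data = nullptr;
```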
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -22,6 +22,7 @@ else()
add_subdirectory(finetune)
add_subdirectory(gritlm)
add_subdirectory(gguf-split)
add_subdirectory(ggml-debug)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
10 changes: 10 additions & 0 deletions examples/ggml-debug/CMakeLists.txt
@@ -0,0 +1,10 @@
set(TARGET ggml-debug)
add_executable(${TARGET} ggml-debug.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# define tests
enable_testing()

add_test(NAME test-ggml-debug COMMAND ggml-debug --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42)
95 changes: 95 additions & 0 deletions examples/ggml-debug/README.md
@@ -0,0 +1,95 @@
# llama.cpp/examples/ggml-debug

A simple example which demonstrates how to use a callback during inference.
It simply prints all operations and tensor data to the console.

Usage:

```shell
ggml-debug \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model phi-2-q4_0.gguf \
--prompt hello \
--seed 42 \
-ngl 33
```

Will print:

```shell
llm_load_tensors: offloaded 33/33 layers to GPU
...
llama_new_context_with_model: n_ctx = 512
...
llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB
llama_new_context_with_model: graph nodes = 1225
llama_new_context_with_model: graph splits = 2
ggml_debug: inp_embd = (f32) GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.0181, 0.0272, 0.0272, ...],
],
]
ggml_debug: norm-0 = (f32) NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -0.6989, 1.0636, 1.0636, ...],
],
]
ggml_debug: norm_w-0 = (f32) MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.1800, 0.2817, 0.2632, ...],
],
]
ggml_debug: attn_norm-0 = (f32) ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.1863, 0.2970, 0.2604, ...],
],
]
ggml_debug: wqkv-0 = (f32) MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1}
[
[
[ -1.1238, 1.2876, -1.8086, ...],
],
]
ggml_debug: bqkv-0 = (f32) ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: bqkv-0 (view) = (f32) VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: Qcur-0 = (f32) CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: Qcur-0 (reshaped) = (f32) RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
[ -0.3608, 0.5076, -1.8866, ...],
[ 1.7643, 0.0273, -2.1065, ...],
...
],
]
ggml_debug: Qcur-0 = (f32) ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
[ -0.3608, 0.5076, -1.8866, ...],
[ 1.7643, 0.0273, -2.1065, ...],
...
],
]
```
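Each line of this trace corresponds to one graph node: the tensor name, its type, the ggml operation, the source tensors with their shapes, and the shape of the result, followed by at most three values per dimension; longer dimensions are elided with `...`.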
168 changes: 168 additions & 0 deletions examples/ggml-debug/ggml-debug.cpp
@@ -0,0 +1,168 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"

#include <cstdio>
#include <random>
#include <string>
#include <tuple>
#include <vector>

/**
* This is the arbitrary data which will be passed to each callback.
* Later on we can, for example, add an operation or tensor name filter from the CLI arg, or a file descriptor to dump the tensor.
*/
struct callback_data {
std::vector<int8_t> data;
};

static std::string ggml_ne_string(const ggml_tensor * t) {
std::string str;
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
str += std::to_string(t->ne[i]);
if (i + 1 < GGML_MAX_DIMS) {
str += ", ";
}
}
return str;
}

static void ggml_print_tensor(int8_t * data, const int64_t * ne, const size_t * nb, int64_t n) {
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
printf(" [\n");
for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
printf(" [\n");
for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
printf(" [");
for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v = *(float *)(data + i);
printf("%8.4f", v);
if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
}
if (ne[0] > n) printf(", ...");
printf("],\n");
}
if (ne[1] > n) printf(" ...\n");
printf(" ],\n");
}
if (ne[2] > n) printf(" ...\n");
printf(" ]\n");
}
}

/**
* GGML operations callback during the graph execution.
*
* @param t current tensor
* @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor
* if we return true, a follow-up call will be made with ask=false in which we can do the actual collection.
* see ggml_backend_sched_eval_callback
* @param user_data user data passed to each callback
* @return true to receive data or continue the graph, false otherwise
*/
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (callback_data *) user_data;

const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];

if (ask) {
return true; // Always retrieve data
}

char src1_str[128] = {0};
if (src1) {
sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
}

printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
t->name, ggml_type_name(t->type), ggml_op_name(t->op),
src0->name, ggml_ne_string(src0).c_str(),
src1 ? src1_str : "",
ggml_ne_string(t).c_str());


// copy the data from the GPU memory if needed
const bool is_host = ggml_backend_buffer_is_host(t->buffer);

if (!is_host) {
auto n_bytes = ggml_nbytes(t);
cb_data->data.resize(n_bytes);
ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
}

if (t->type == GGML_TYPE_F32 || t->type == GGML_TYPE_F16) {
int8_t * data = is_host ? (int8_t *) t->data : cb_data->data.data();
ggml_print_tensor(data, t->ne, t->nb, 3);
}

return true;
}

static bool run(llama_context * ctx, const gpt_params & params) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}

return true;
}

int main(int argc, char ** argv) {

callback_data cb_data;

gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

print_build_info();

std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}

llama_backend_init();
llama_numa_init(params.numa);

// pass the callback to the backend scheduler
// it will be executed for each node during the graph computation
params.cb_eval = ggml_debug;
params.cb_eval_user_data = &cb_data;

// init
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
}

// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}

bool OK = run(ctx, params);
if (!OK) {
return 1;
}

llama_print_timings(ctx);

llama_free(ctx);
llama_free_model(model);

llama_backend_free();

return 0;
}
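The comment on callback_data above mentions adding an operation or tensor name filter as a possible follow-up. A hypothetical sketch of such an extension (not part of this PR; the filter member and the substring match are assumptions):

```cpp
// Hypothetical extension, not part of this PR: only collect tensors whose name
// contains a user-supplied substring (e.g. taken from a CLI argument).
#include "ggml.h"
#include <string>
#include <vector>

struct callback_data {
    std::vector<int8_t> data;
    std::string         filter; // empty means "print everything"
};

static bool ggml_debug_filtered(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * cb_data = (callback_data *) user_data;
    if (ask) {
        // only request data for tensors whose name matches the filter
        return cb_data->filter.empty() ||
               std::string(t->name).find(cb_data->filter) != std::string::npos;
    }
    // ... same printing logic as ggml_debug above ...
    return true;
}
```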
25 changes: 9 additions & 16 deletions examples/imatrix/imatrix.cpp
@@ -596,24 +596,17 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model_params mparams = llama_model_params_from_gpt_params(params);

llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}

llama_context_params cparams = llama_context_params_from_gpt_params(params);

// pass the callback to the backend scheduler
// it will be executed for each node during the graph computation
cparams.cb_eval = ik_collect_imatrix;
cparams.cb_eval_user_data = NULL;

llama_context * ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: unable to create context\n", __func__);
params.cb_eval = ik_collect_imatrix;
params.cb_eval_user_data = NULL;

// init
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
}

2 changes: 1 addition & 1 deletion llama.cpp
@@ -11054,7 +11054,7 @@ struct llm_tokenizer_bpe {
add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol
}

// add the fnished tokens to the final list keeping correct order for next and prev
// add the finished tokens to the final list keeping correct order for next and prev
for (auto & sym : symbols) {
if (sym.n > 0) {
sym.prev = final_prev_index;