
Commit 6aea16e

add basic tensor data validation function
1 parent b4e4b8a

File tree

3 files changed, +224 -4 lines changed


ggml-quants.c

Lines changed: 203 additions & 0 deletions
@@ -12389,3 +12389,206 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
     block_iq2_s * restrict y = vy;
     quantize_row_iq2_s_reference(x, y, k);
 }
+
+static bool validate_float(float f, size_t i) {
+    if (isinf(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+static bool validate_f16(ggml_fp16_t f, size_t i) {
+    return validate_float(GGML_FP16_TO_FP32(f), i);
+}
+
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_f16(q[i].d, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_f16(q[i].d, i) || !validate_f16(q[i].m, i)) { \
+            return false; \
+        } \
+    }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+    if (type < 0 || type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+        return false;
+    }
+
+    // size check
+    if (nbytes % ggml_type_size(type) != 0) {
+        fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
+        return false;
+    }
+
+    size_t nb = nbytes/ggml_type_size(type);
+
+    switch (type) {
+        case GGML_TYPE_F16:
+            {
+                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_f16(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                const float * f = (const float *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F64:
+            {
+                const double * f = (const double *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+            } break;
+        case GGML_TYPE_Q2_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+            } break;
+        case GGML_TYPE_Q8_K:
+            {
+                const block_q8_K * q = (const block_q8_K *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(q[i].d, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_M:
+            {
+                const block_iq1_m * q = (const block_iq1_m *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                #if QK_K == 64
+                    if (!validate_f16(q[i].d, i)) {
+                        return false;
+                    }
+                #else
+                    iq1m_scale_t scale;
+                    const uint16_t * sc = (const uint16_t *)q[i].scales;
+                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                    if (!validate_f16(scale.f16, i)) {
+                        return false;
+                    }
+                #endif
+                }
+            } break;
+        case GGML_TYPE_IQ2_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+            } break;
+
+        case GGML_TYPE_IQ3_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_XS:
+        #if QK_K != 64
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+            } break;
+        #endif
+        // with QK_K == 64, iq4_xs is iq4_nl
+        case GGML_TYPE_IQ4_NL:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+            // nothing to validate
+            break;
+        default:
+            {
+                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+                return false;
+            }
+    }
+
+    return true;
+}
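
For readers skimming the macro-heavy switch above: each VALIDATE_ROW_DATA_D_F16_IMPL use expands to a loop that checks the per-block f16 scale `d` for inf/nan. The self-contained sketch below reproduces that loop shape with a toy block type; `toy_block`, `toy_fp16`, and the bit-level f16 classification are illustrative stand-ins, not ggml's real `block_q4_0` or `GGML_FP16_TO_FP32`.

/* Hedged, self-contained sketch of the per-block scale check that
 * VALIDATE_ROW_DATA_D_F16_IMPL performs. All names here are toys. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

typedef uint16_t toy_fp16;              // f16 stored as raw bits

typedef struct {
    toy_fp16 d;                         // per-block scale, like block_q4_0.d
    uint8_t  qs[16];                    // packed quants (ignored by the check)
} toy_block;

/* An IEEE-754 binary16 value is inf or nan exactly when its 5 exponent
 * bits (mask 0x7C00) are all ones. */
static bool toy_validate_f16(toy_fp16 f, size_t i) {
    if ((f & 0x7C00) == 0x7C00) {
        fprintf(stderr, "toy validator: found inf/nan scale at block %zu\n", i);
        return false;
    }
    return true;
}

int main(void) {
    toy_block blocks[2] = {0};
    blocks[0].d = 0x3C00;               // 1.0 in f16: valid
    blocks[1].d = 0x7E00;               // a nan bit pattern: must be rejected

    /* the same loop shape the macro expands to */
    for (size_t i = 0; i < 2; ++i) {
        if (!toy_validate_f16(blocks[i].d, i)) {
            return 1;
        }
    }
    return 0;
}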

ggml.h

Lines changed: 2 additions & 0 deletions
@@ -762,6 +762,8 @@ extern "C" {
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
 
+    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
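
A minimal usage sketch for the new public entry point declared above, assuming a program compiled and linked against ggml with this commit applied; the F32 row deliberately contains a NaN, so `ggml_validate_row_data` should print a diagnostic and return false.

#include <math.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    float row[4] = { 1.0f, 2.0f, NAN, 4.0f };   // deliberately corrupted row

    // nbytes must be a whole number of elements of the given type
    if (!ggml_validate_row_data(GGML_TYPE_F32, row, sizeof(row))) {
        fprintf(stderr, "row rejected, as expected\n");
        return 1;
    }

    printf("row is clean\n");
    return 0;
}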

llama.cpp

Lines changed: 19 additions & 4 deletions
@@ -3472,6 +3472,10 @@ struct llama_model_loader {
             file->seek(w.offs, SEEK_SET);
             file->read_raw(cur->data, ggml_nbytes(cur));
         }
+
+        if (!ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
+            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+        }
     }
 
     size_t size_done = 0;
@@ -3509,12 +3513,17 @@ struct llama_model_loader {
                 if (bufs_mmap.count(weight->idx)) {
                     buf_mmap = bufs_mmap.at(weight->idx);
                 }
+
+                if (!ggml_validate_row_data(cur->type, (uint8_t *) mapping->addr + weight->offs, n_size)) {
+                    throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                }
+
                 GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
                 if (buf_mmap && cur->data == nullptr) {
                     ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs);
                     if (lmlocks) {
                         const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + ggml_nbytes(cur));
+                        lmlock->grow_to(weight->offs + n_size);
                     }
 
                     auto & mmap_used = mmaps_used[weight->idx];
@@ -3528,12 +3537,18 @@ struct llama_model_loader {
                 const auto & file = files.at(weight->idx);
                 if (ggml_backend_buffer_is_host(cur->buffer)) {
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, ggml_nbytes(cur));
+                    file->read_raw(cur->data, n_size);
+                    if (!ggml_validate_row_data(cur->type, cur->data, n_size)) {
+                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                    }
                 } else {
-                    read_buf.resize(ggml_nbytes(cur));
+                    read_buf.resize(n_size);
                     file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(read_buf.data(), ggml_nbytes(cur));
+                    file->read_raw(read_buf.data(), n_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+                    if (!ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+                    }
                 }
             }
 
