Skip to content

Commit 34ed7f9

Browse files
committed
ggml : add ggml_is_quantized()
1 parent bbd2921 commit 34ed7f9

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

ggml.c

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3450,6 +3450,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
34503450
};
34513451
static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
34523452

3453+
// Per-type flag table, indexed by enum ggml_type value (see ggml_is_quantized()).
// Entries are true for Q4_0, Q4_1, Q4_2 and Q8_0, and false for F32, F16, I8,
// I16 and I32. Designated initializers keep each flag tied to its enum constant
// regardless of the enum's numeric ordering.
// NOTE(review): nothing visible here writes to this table, and the neighbouring
// GGML_TYPE_NAME table is declared `static const` - consider adding `const`.
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3454+
[GGML_TYPE_F32] = false,
3455+
[GGML_TYPE_F16] = false,
3456+
[GGML_TYPE_Q4_0] = true,
3457+
[GGML_TYPE_Q4_1] = true,
3458+
[GGML_TYPE_Q4_2] = true,
3459+
[GGML_TYPE_Q8_0] = true,
3460+
[GGML_TYPE_I8] = false,
3461+
[GGML_TYPE_I16] = false,
3462+
[GGML_TYPE_I32] = false,
3463+
};
3464+
// Compile-time guard: the build fails if GGML_TYPE_COUNT stops being 9, so a
// newly added type cannot be forgotten in the table above.
static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
3465+
34533466
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
34543467
"NONE",
34553468

@@ -3710,6 +3723,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
37103723
(t0->ne[3] == t1->ne[3]);
37113724
}
37123725

3726+
// Returns the GGML_IS_QUANTIZED table entry for `type`.
// The lookup is a plain array index with no range check, so `type` must be a
// value below GGML_TYPE_COUNT (the declared size of the table).
static inline bool ggml_is_quantized(enum ggml_type type) {
3727+
return GGML_IS_QUANTIZED[type];
3728+
}
3729+
37133730
// Returns true when tensor->nb[0] is strictly greater than tensor->nb[1].
// NOTE(review): presumably nb[] holds the per-dimension byte strides, so this
// detects swapped first two dimensions - confirm against the ggml_tensor definition.
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
37143731
return tensor->nb[0] > tensor->nb[1];
37153732
}
@@ -5831,7 +5848,7 @@ static void ggml_compute_forward_dup_f16(
58315848
}
58325849
}
58335850
}
5834-
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
5851+
} else if (ggml_is_quantized(dst->type)) {
58355852
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
58365853
size_t id = 0;
58375854
uint8_t * dst_ptr = (uint8_t *) dst->data;
@@ -6043,7 +6060,7 @@ static void ggml_compute_forward_dup_f32(
60436060
}
60446061
}
60456062
}
6046-
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
6063+
} else if (ggml_is_quantized(dst->type)) {
60476064
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
60486065
size_t id = 0;
60496066
uint8_t * dst_ptr = (uint8_t *) dst->data;
@@ -6406,7 +6423,7 @@ static void ggml_compute_forward_add_q_f32(
64066423
GGML_ASSERT(nb1 <= nb2);
64076424
GGML_ASSERT(nb2 <= nb3);
64086425

6409-
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_2);
6426+
GGML_ASSERT(ggml_is_quantized(src0->type));
64106427
GGML_ASSERT(dst->type == src0->type);
64116428
GGML_ASSERT(src1->type == GGML_TYPE_F32);
64126429

@@ -10623,7 +10640,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1062310640
node->n_tasks = 1;
1062410641

1062510642
size_t cur = 0;
10626-
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1 || node->type == GGML_TYPE_Q4_2) {
10643+
if (ggml_is_quantized(node->type)) {
1062710644
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
1062810645
}
1062910646

@@ -10635,7 +10652,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1063510652

1063610653
size_t cur = 0;
1063710654

10638-
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1 || node->src0->type == GGML_TYPE_Q4_2) {
10655+
if (ggml_is_quantized(node->src0->type)) {
1063910656
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
1064010657
}
1064110658

0 commit comments

Comments
 (0)