Skip to content

Commit fe85929

Browse files
committed
ggml : add ggml_is_quantized()
1 parent e435b81 commit fe85929

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

ggml.c

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,6 +3449,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
34493449
};
34503450
static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
34513451

3452+
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3453+
[GGML_TYPE_F32] = false,
3454+
[GGML_TYPE_F16] = false,
3455+
[GGML_TYPE_Q4_0] = true,
3456+
[GGML_TYPE_Q4_1] = true,
3457+
[GGML_TYPE_Q4_2] = true,
3458+
[GGML_TYPE_Q8_0] = true,
3459+
[GGML_TYPE_I8] = false,
3460+
[GGML_TYPE_I16] = false,
3461+
[GGML_TYPE_I32] = false,
3462+
};
3463+
static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
3464+
34523465
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
34533466
"NONE",
34543467

@@ -3709,6 +3722,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
37093722
(t0->ne[3] == t1->ne[3]);
37103723
}
37113724

3725+
static inline bool ggml_is_quantized(enum ggml_type type) {
3726+
return GGML_IS_QUANTIZED[type];
3727+
}
3728+
37123729
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
37133730
return tensor->nb[0] > tensor->nb[1];
37143731
}
@@ -5830,7 +5847,7 @@ static void ggml_compute_forward_dup_f16(
58305847
}
58315848
}
58325849
}
5833-
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
5850+
} else if (ggml_is_quantized(dst->type)) {
58345851
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
58355852
size_t id = 0;
58365853
uint8_t * dst_ptr = (uint8_t *) dst->data;
@@ -6042,7 +6059,7 @@ static void ggml_compute_forward_dup_f32(
60426059
}
60436060
}
60446061
}
6045-
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
6062+
} else if (ggml_is_quantized(dst->type)) {
60466063
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
60476064
size_t id = 0;
60486065
uint8_t * dst_ptr = (uint8_t *) dst->data;
@@ -6405,7 +6422,7 @@ static void ggml_compute_forward_add_q_f32(
64056422
GGML_ASSERT(nb1 <= nb2);
64066423
GGML_ASSERT(nb2 <= nb3);
64076424

6408-
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_2);
6425+
GGML_ASSERT(ggml_is_quantized(src0->type));
64096426
GGML_ASSERT(dst->type == src0->type);
64106427
GGML_ASSERT(src1->type == GGML_TYPE_F32);
64116428

@@ -10622,7 +10639,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1062210639
node->n_tasks = 1;
1062310640

1062410641
size_t cur = 0;
10625-
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1 || node->type == GGML_TYPE_Q4_2) {
10642+
if (ggml_is_quantized(node->type)) {
1062610643
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
1062710644
}
1062810645

@@ -10634,7 +10651,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1063410651

1063510652
size_t cur = 0;
1063610653

10637-
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1 || node->src0->type == GGML_TYPE_Q4_2) {
10654+
if (ggml_is_quantized(node->src0->type)) {
1063810655
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
1063910656
}
1064010657

0 commit comments

Comments
 (0)