@@ -3449,6 +3449,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3449
3449
};
3450
3450
static_assert (GGML_TYPE_COUNT == 9 , "GGML_TYPE_NAME is outdated" );
3451
3451
3452
+ static bool GGML_IS_QUANTIZED [GGML_TYPE_COUNT ] = {
3453
+ [GGML_TYPE_F32 ] = false,
3454
+ [GGML_TYPE_F16 ] = false,
3455
+ [GGML_TYPE_Q4_0 ] = true,
3456
+ [GGML_TYPE_Q4_1 ] = true,
3457
+ [GGML_TYPE_Q4_2 ] = true,
3458
+ [GGML_TYPE_Q8_0 ] = true,
3459
+ [GGML_TYPE_I8 ] = false,
3460
+ [GGML_TYPE_I16 ] = false,
3461
+ [GGML_TYPE_I32 ] = false,
3462
+ };
3463
+ static_assert (GGML_TYPE_COUNT == 9 , "GGML_IS_QUANTIZED is outdated" );
3464
+
3452
3465
static const char * GGML_OP_LABEL [GGML_OP_COUNT ] = {
3453
3466
"NONE" ,
3454
3467
@@ -3709,6 +3722,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3709
3722
(t0 -> ne [3 ] == t1 -> ne [3 ]);
3710
3723
}
3711
3724
3725
+ static inline bool ggml_is_quantized (enum ggml_type type ) {
3726
+ return GGML_IS_QUANTIZED [type ];
3727
+ }
3728
+
3712
3729
static inline bool ggml_is_transposed (const struct ggml_tensor * tensor ) {
3713
3730
return tensor -> nb [0 ] > tensor -> nb [1 ];
3714
3731
}
@@ -5830,7 +5847,7 @@ static void ggml_compute_forward_dup_f16(
5830
5847
}
5831
5848
}
5832
5849
}
5833
- } else if (dst -> type == GGML_TYPE_Q4_0 || dst -> type == GGML_TYPE_Q4_1 || dst -> type == GGML_TYPE_Q4_2 ) {
5850
+ } else if (ggml_is_quantized ( dst -> type ) ) {
5834
5851
quantize_row_q_t const quantize_row_q = quantize_fns [dst -> type ].quantize_row_q ;
5835
5852
size_t id = 0 ;
5836
5853
uint8_t * dst_ptr = (uint8_t * ) dst -> data ;
@@ -6042,7 +6059,7 @@ static void ggml_compute_forward_dup_f32(
6042
6059
}
6043
6060
}
6044
6061
}
6045
- } else if (dst -> type == GGML_TYPE_Q4_0 || dst -> type == GGML_TYPE_Q4_1 || dst -> type == GGML_TYPE_Q4_2 ) {
6062
+ } else if (ggml_is_quantized ( dst -> type ) ) {
6046
6063
quantize_row_q_t const quantize_row_q = quantize_fns [dst -> type ].quantize_row_q ;
6047
6064
size_t id = 0 ;
6048
6065
uint8_t * dst_ptr = (uint8_t * ) dst -> data ;
@@ -6405,7 +6422,7 @@ static void ggml_compute_forward_add_q_f32(
6405
6422
GGML_ASSERT (nb1 <= nb2 );
6406
6423
GGML_ASSERT (nb2 <= nb3 );
6407
6424
6408
- GGML_ASSERT (src0 -> type == GGML_TYPE_Q4_0 || src0 -> type == GGML_TYPE_Q4_1 || src0 -> type == GGML_TYPE_Q4_2 );
6425
+ GGML_ASSERT (ggml_is_quantized ( src0 -> type ) );
6409
6426
GGML_ASSERT (dst -> type == src0 -> type );
6410
6427
GGML_ASSERT (src1 -> type == GGML_TYPE_F32 );
6411
6428
@@ -10622,7 +10639,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10622
10639
node -> n_tasks = 1 ;
10623
10640
10624
10641
size_t cur = 0 ;
10625
- if (node -> type == GGML_TYPE_Q4_0 || node -> type == GGML_TYPE_Q4_1 || node -> type == GGML_TYPE_Q4_2 ) {
10642
+ if (ggml_is_quantized ( node -> type ) ) {
10626
10643
cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> ne [0 ];
10627
10644
}
10628
10645
@@ -10634,7 +10651,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10634
10651
10635
10652
size_t cur = 0 ;
10636
10653
10637
- if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 || node -> src0 -> type == GGML_TYPE_Q4_2 ) {
10654
+ if (ggml_is_quantized ( node -> src0 -> type ) ) {
10638
10655
cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
10639
10656
}
10640
10657
0 commit comments