@@ -3450,6 +3450,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3450
3450
};
3451
3451
static_assert (GGML_TYPE_COUNT == 9 , "GGML_TYPE_NAME is outdated" );
3452
3452
3453
+ static bool GGML_IS_QUANTIZED [GGML_TYPE_COUNT ] = {
3454
+ [GGML_TYPE_F32 ] = false,
3455
+ [GGML_TYPE_F16 ] = false,
3456
+ [GGML_TYPE_Q4_0 ] = true,
3457
+ [GGML_TYPE_Q4_1 ] = true,
3458
+ [GGML_TYPE_Q4_2 ] = true,
3459
+ [GGML_TYPE_Q8_0 ] = true,
3460
+ [GGML_TYPE_I8 ] = false,
3461
+ [GGML_TYPE_I16 ] = false,
3462
+ [GGML_TYPE_I32 ] = false,
3463
+ };
3464
+ static_assert (GGML_TYPE_COUNT == 9 , "GGML_IS_QUANTIZED is outdated" );
3465
+
3453
3466
static const char * GGML_OP_LABEL [GGML_OP_COUNT ] = {
3454
3467
"NONE" ,
3455
3468
@@ -3710,6 +3723,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3710
3723
(t0 -> ne [3 ] == t1 -> ne [3 ]);
3711
3724
}
3712
3725
3726
+ static inline bool ggml_is_quantized (enum ggml_type type ) {
3727
+ return GGML_IS_QUANTIZED [type ];
3728
+ }
3729
+
3713
3730
static inline bool ggml_is_transposed (const struct ggml_tensor * tensor ) {
3714
3731
return tensor -> nb [0 ] > tensor -> nb [1 ];
3715
3732
}
@@ -5831,7 +5848,7 @@ static void ggml_compute_forward_dup_f16(
5831
5848
}
5832
5849
}
5833
5850
}
5834
- } else if (dst -> type == GGML_TYPE_Q4_0 || dst -> type == GGML_TYPE_Q4_1 || dst -> type == GGML_TYPE_Q4_2 ) {
5851
+ } else if (ggml_is_quantized ( dst -> type ) ) {
5835
5852
quantize_row_q_t const quantize_row_q = quantize_fns [dst -> type ].quantize_row_q ;
5836
5853
size_t id = 0 ;
5837
5854
uint8_t * dst_ptr = (uint8_t * ) dst -> data ;
@@ -6043,7 +6060,7 @@ static void ggml_compute_forward_dup_f32(
6043
6060
}
6044
6061
}
6045
6062
}
6046
- } else if (dst -> type == GGML_TYPE_Q4_0 || dst -> type == GGML_TYPE_Q4_1 || dst -> type == GGML_TYPE_Q4_2 ) {
6063
+ } else if (ggml_is_quantized ( dst -> type ) ) {
6047
6064
quantize_row_q_t const quantize_row_q = quantize_fns [dst -> type ].quantize_row_q ;
6048
6065
size_t id = 0 ;
6049
6066
uint8_t * dst_ptr = (uint8_t * ) dst -> data ;
@@ -6406,7 +6423,7 @@ static void ggml_compute_forward_add_q_f32(
6406
6423
GGML_ASSERT (nb1 <= nb2 );
6407
6424
GGML_ASSERT (nb2 <= nb3 );
6408
6425
6409
- GGML_ASSERT (src0 -> type == GGML_TYPE_Q4_0 || src0 -> type == GGML_TYPE_Q4_1 || src0 -> type == GGML_TYPE_Q4_2 );
6426
+ GGML_ASSERT (ggml_is_quantized ( src0 -> type ) );
6410
6427
GGML_ASSERT (dst -> type == src0 -> type );
6411
6428
GGML_ASSERT (src1 -> type == GGML_TYPE_F32 );
6412
6429
@@ -10623,7 +10640,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10623
10640
node -> n_tasks = 1 ;
10624
10641
10625
10642
size_t cur = 0 ;
10626
- if (node -> type == GGML_TYPE_Q4_0 || node -> type == GGML_TYPE_Q4_1 || node -> type == GGML_TYPE_Q4_2 ) {
10643
+ if (ggml_is_quantized ( node -> type ) ) {
10627
10644
cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> ne [0 ];
10628
10645
}
10629
10646
@@ -10635,7 +10652,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10635
10652
10636
10653
size_t cur = 0 ;
10637
10654
10638
- if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 || node -> src0 -> type == GGML_TYPE_Q4_2 ) {
10655
+ if (ggml_is_quantized ( node -> src0 -> type ) ) {
10639
10656
cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
10640
10657
}
10641
10658
0 commit comments