@@ -2318,6 +2318,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
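The quantize_fns table above collects the per-type quantization callbacks, and ggml_internal_get_quantize_fn gives test code a way to reach them. Below is a minimal sketch (not part of the diff) of how a test might round-trip one row through the Q4_0 callbacks; it assumes the internal declarations of quantize_fns_t and ggml_internal_get_quantize_fn are visible to the test translation unit, and that the row length is a multiple of the Q4_0 block size of 32.

// Hypothetical test sketch: quantize a row to Q4_0 and dequantize it back,
// then report the worst-case round-trip error. Buffer sizes and test data
// are made up for illustration only.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

void test_q4_0_roundtrip(void) {
    const int n = 256;                  // multiple of the Q4_0 block size (32)
    float   src[256];
    float   dst[256];
    uint8_t quantized[256];             // scratch for the quantized row

    for (int i = 0; i < n; i++) {
        src[i] = 0.1f*(i % 32) - 1.6f;  // arbitrary test data
    }

    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);

    fns.quantize_row_q(src, quantized, n);     // float -> Q4_0 blocks
    fns.dequantize_row_q(quantized, dst, n);   // Q4_0 blocks -> float

    float max_err = 0.0f;
    for (int i = 0; i < n; i++) {
        max_err = fmaxf(max_err, fabsf(src[i] - dst[i]));
    }
    printf("Q4_0 round-trip max error: %f\n", max_err);
}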
@@ -5099,13 +5121,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5117,12 +5139,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
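A note on the row bookkeeping in ggml_compute_forward_add_q_f32 above: the flat counter ir runs over all ne01*ne02*ne03 rows of src0, and the divide/remainder steps recover the per-dimension indices i01, i02, i03. Here is a small standalone sketch of that decomposition; the sizes ne01 = 4, ne02 = 3, ne03 = 2 are illustrative only and not taken from the diff.

// Standalone illustration of the ir -> (i01, i02, i03) decomposition used above.
#include <stdio.h>

int main(void) {
    const int ne01 = 4, ne02 = 3, ne03 = 2;   // example row count and outer dimensions
    const int nr = ne01*ne02*ne03;            // total rows, as in the kernel

    for (int ir = 0; ir < nr; ++ir) {
        const int i03 = ir/(ne02*ne01);                     // slowest-varying index
        const int i02 = (ir - i03*ne02*ne01)/ne01;          // middle index
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);    // fastest-varying (row) index
        printf("ir=%2d -> i01=%d i02=%d i03=%d\n", ir, i01, i02, i03);
    }
    return 0;
}

Because src1 and dst are asserted to have the same shape as src0, the same triple also indexes their rows, which is why i11/i12/i13 and i1/i2/i3 are plain copies in the kernel.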
@@ -5135,10 +5308,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
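With the Q4_0/Q4_1 cases wired up above, ggml_add can accept a quantized src0 together with an F32 src1 of the same shape and write a re-quantized result. A rough usage sketch under the ggml API of this period follows; the tensor sizes and the helper name are invented for illustration, and data filling and error handling are omitted.

#include "ggml.h"

// Hypothetical example: add an F32 delta onto a Q4_0 weight matrix.
// 64x64 is an arbitrary square size; the row length must be a multiple of 32 for Q4_0.
void add_f32_delta_to_q4_0(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,   // small scratch arena for this example
        .mem_buffer = NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * w     = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 64, 64);
    struct ggml_tensor * delta = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  64, 64);

    // ... fill w with quantized data and delta with float data ...

    // dst inherits src0's type, so this dispatches to ggml_compute_forward_add_q_f32
    struct ggml_tensor * w2 = ggml_add(ctx, w, delta);

    struct ggml_cgraph gf = ggml_build_forward(w2);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
}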
@@ -6523,27 +6707,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,