@@ -2336,6 +2336,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
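+// per-type quantization helpers (dequantize, quantize, reference quantize, vec dot), indexed by ggml_type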
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -5184,13 +5206,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5202,12 +5224,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
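+// element-wise add of two F16 tensors; the sum is computed in F32 and rounded back to F16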
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
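+// add an F32 tensor to a quantized tensor (Q4_0/Q4_1): each src0 row is dequantized to F32,
+// src1 is accumulated into it, and the result is quantized back into dst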
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5220,10 +5393,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
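+                // dst is F16; src1 may be either F16 or F32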
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6608,27 +6792,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,