@@ -5265,6 +5265,8 @@ static void ggml_compute_forward_add_q_f32(
5265
5265
const int ir0 = dr * ith ;
5266
5266
const int ir1 = MIN (ir0 + dr , nr );
5267
5267
5268
+ float * wdata = (float * ) params -> wdata + ne00 * ith ;
5269
+
5268
5270
for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
5269
5271
// src0 indices
5270
5272
const int i03 = ir /(ne02 * ne01 );
@@ -5287,12 +5289,11 @@ static void ggml_compute_forward_add_q_f32(
5287
5289
assert (ne00 % 32 == 0 );
5288
5290
5289
5291
// unquantize row from src0 to temp buffer
5290
- float tmp [ne00 ];
5291
- dequantize_row_q (src0_row , tmp , ne00 );
5292
+ dequantize_row_q (src0_row , wdata , ne00 );
5292
5293
// add src1
5293
- ggml_vec_acc_f32 (ne00 , tmp , src1_row );
5294
+ ggml_vec_acc_f32 (ne00 , wdata , src1_row );
5294
5295
// quantize row to dst
5295
- quantize_row_q (tmp , dst_row , ne00 );
5296
+ quantize_row_q (wdata , dst_row , ne00 );
5296
5297
}
5297
5298
}
5298
5299
@@ -9481,6 +9482,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9481
9482
case GGML_OP_ADD :
9482
9483
{
9483
9484
node -> n_tasks = n_threads ;
9485
+
9486
+ size_t cur = 0 ;
9487
+
9488
+ if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 ) {
9489
+ cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
9490
+ }
9491
+
9492
+ work_size = MAX (work_size , cur );
9484
9493
} break ;
9485
9494
case GGML_OP_SUB :
9486
9495
case GGML_OP_MUL :
0 commit comments