@@ -2303,13 +2303,15 @@ struct ggml_tensor * ggml_repeat(
2303
2303
struct ggml_tensor * ggml_repeat_back (
2304
2304
struct ggml_context * ctx ,
2305
2305
struct ggml_tensor * a ,
2306
- struct ggml_tensor * b ) {
2306
+ struct ggml_tensor * b ,
2307
+ bool gqa_mode ) {
2307
2308
GGML_ASSERT (ggml_can_repeat (b , a ));
2308
2309
2309
2310
struct ggml_tensor * result = ggml_new_tensor (ctx , a -> type , GGML_MAX_DIMS , b -> ne );
2310
2311
2311
2312
result -> op = GGML_OP_REPEAT_BACK ;
2312
2313
result -> src [0 ] = a ;
2314
+ result -> op_params [1 ] = gqa_mode ? 1 : 0 ;
2313
2315
2314
2316
return result ;
2315
2317
}
@@ -5129,7 +5131,7 @@ static void ggml_compute_backward(
5129
5131
if (src1_needs_grads ) {
5130
5132
struct ggml_tensor * tmp = grad ;
5131
5133
if (!ggml_are_same_shape (src0 , src1 )) {
5132
- tmp = ggml_repeat_back (ctx , tmp , src1 );
5134
+ tmp = ggml_repeat_back (ctx , tmp , src1 , false );
5133
5135
}
5134
5136
ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
5135
5137
}
@@ -5174,7 +5176,7 @@ static void ggml_compute_backward(
5174
5176
if (src1_needs_grads ) {
5175
5177
struct ggml_tensor * tmp = ggml_mul (ctx , src0 , grad );
5176
5178
if (!ggml_are_same_shape (src0 , src1 )) {
5177
- tmp = ggml_repeat_back (ctx , tmp , src1 );
5179
+ tmp = ggml_repeat_back (ctx , tmp , src1 , false );
5178
5180
}
5179
5181
ggml_add_or_set (ctx , cgraph , isrc1 , tmp );
5180
5182
}
@@ -5229,7 +5231,7 @@ static void ggml_compute_backward(
5229
5231
} break ;
5230
5232
case GGML_OP_REPEAT : {
5231
5233
if (src0_needs_grads ) {
5232
- ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 ));
5234
+ ggml_add_or_set (ctx , cgraph , isrc0 , ggml_repeat_back (ctx , grad , src0 , false ));
5233
5235
}
5234
5236
} break ;
5235
5237
case GGML_OP_REPEAT_BACK : {
@@ -5268,8 +5270,7 @@ static void ggml_compute_backward(
5268
5270
if (!ggml_are_same_shape (tmp , src0 )) {
5269
5271
GGML_ASSERT (tmp -> ne [0 ] == src0 -> ne [0 ]);
5270
5272
GGML_ASSERT (tmp -> ne [1 ] == src0 -> ne [1 ]);
5271
- tmp = ggml_repeat_back (ctx , tmp , src0 );
5272
- tmp -> op_params [0 ] = 1 ; // FIXME
5273
+ tmp = ggml_repeat_back (ctx , tmp , src0 , true);
5273
5274
}
5274
5275
ggml_add_or_set (ctx , cgraph , isrc0 , tmp );
5275
5276
}
0 commit comments