@@ -4034,7 +4034,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4082,7 +4082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
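The count bump from 38 to 39 reflects one new entry in the ggml_op enum, which this excerpt does not show. A hypothetical sketch of the header-side change (the enum lives in ggml.h; the exact placement and the matching GGML_OP_LABEL/GGML_OP_SYMBOL strings are assumptions here):

    // ggml.h (sketch, not shown in this diff): one new op entry is what
    // pushes GGML_OP_COUNT from 38 to 39.
    enum ggml_op {
        // ... existing ops (GGML_OP_SOFT_MAX, GGML_OP_ROPE, ...) ...
        GGML_OP_ALIBI,
        // ... remaining ops ...
        GGML_OP_COUNT,
    };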
@@ -6080,6 +6080,37 @@ struct ggml_tensor * ggml_rope(
     return result;
 }
 
+// ggml_alibi
+
+struct ggml_tensor * ggml_alibi(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_head) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implementing backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_head;
+
+    result->op   = GGML_OP_ALIBI;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s
 
 struct ggml_tensor * ggml_conv_1d_1s(
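The new builder follows the usual ggml pattern: the result is a view of a (the bias is added when the graph is evaluated) and the two integer parameters travel in a small I32 tensor wired in as src1. A minimal usage sketch, assuming an attention graph where ctx0, KQ_scaled, n_past and n_head are already in scope (hypothetical names; the kernel comment further down confirms the intended input is the scaled KQ matrix):

    // Sketch: bias the scaled attention scores, then soft-max as usual.
    struct ggml_tensor * KQ_biased   = ggml_alibi(ctx0, KQ_scaled, n_past, n_head);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_biased);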
@@ -9300,6 +9331,162 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n       = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(float));
+    assert(ne1 + n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                float * const src  = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *       pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                pdst[0] = (j + 1) * m_k + src[0];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_alibi_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n       = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+    assert(ne1 + n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *             pdst =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                // we return F32
+                pdst[0] = (j + 1) * m_k + GGML_FP16_TO_FP32(src[0]);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_alibi(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_rope
 
 static void ggml_compute_forward_rope_f32(
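The slope schedule in both kernels is the one from the ALiBi paper ("Train Short, Test Long", Press et al.): with n_head a power of two, head k gets slope 2^(-8(k+1)/n_head), a geometric sequence with ratio m0; when n_head is not a power of two, the code falls back to the nearest lower power of two (n_heads_log2_floor) and interleaves a second sequence built from m1. Each destination element then receives (j + 1) * m_k on top of the source value. A self-contained sketch of just the slope computation, extracted from the kernel for clarity (n_head = 12 is an arbitrary example):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int n_head = 12; // non-power-of-two, so both sequences are used
        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); // = 8

        const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); // main ratio, 2^(-8/8)
        const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); // interleaved ratio, 2^(-4/8)

        for (int k = 0; k < n_head; k++) {
            const float m_k = (k < n_heads_log2_floor)
                ? powf(m0, k + 1)                              // first 8 heads
                : powf(m1, 2 * (k - n_heads_log2_floor) + 1);  // odd powers of m1
            printf("head %2d: slope %f\n", k, m_k);
        }
        return 0;
    }

Note that the F16 variant reads ggml_fp16_t but writes float, so dst is expected to hold F32 values even when src0 is F16 (hence the "we return F32" comment).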
@@ -10938,6 +11125,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ALIBI:
+            {
+                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -11140,6 +11331,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ALIBI:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
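The backward branch is stubbed out. Since alibi only adds a constant (per position and head) to src0, its Jacobian is the identity, so a hypothetical implementation, not part of this diff, could mirror the GGML_OP_ADD case and simply accumulate the incoming gradient:

    // Sketch only, assuming the internal ggml_add_impl helper that the
    // other backward cases use; the diff itself leaves this unimplemented.
    if (src0->grad) {
        src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
    }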
@@ -11673,6 +11868,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = n_threads;
                 } break;
+            case GGML_OP_ALIBI:
+                {
+                    node->n_tasks = 1; //TODO
+                } break;
             case GGML_OP_CONV_1D_1S:
             case GGML_OP_CONV_1D_2S:
                 {
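Setting n_tasks = 1 keeps the new op on a single worker, which matches the assert(params->ith == 0) guard in the kernels above; the //TODO marks it as a candidate for parallelization. For reference, a sketch of the usual ggml row-splitting pattern the TODO might grow into (not part of this diff):

    // Each worker derives its own row range from its index (ith) and the
    // worker count (nth), the way other multi-threaded ggml ops do.
    const int ith = params->ith;
    const int nth = params->nth;
    const int nr  = ggml_nrows(src0);                 // total rows
    const int dr  = (nr + nth - 1)/nth;               // rows per thread
    const int ir0 = dr*ith;                           // first row for this thread
    const int ir1 = ir0 + dr < nr ? ir0 + dr : nr;    // one past the last row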