
Commit 55390bc

ggml : sync ggml (ggml_alibi)
1 parent 5fba3c0 commit 55390bc

File tree

ggml.c
ggml.h

2 files changed: +210 -2 lines changed

ggml.c

Lines changed: 201 additions & 2 deletions
@@ -4034,7 +4034,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };

-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4082,7 +4082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };

-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -6080,6 +6080,37 @@ struct ggml_tensor * ggml_rope(
     return result;
 }

+// ggml_alibi
+
+struct ggml_tensor * ggml_alibi(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int n_past,
+        int n_head) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_head;
+
+    result->op   = GGML_OP_ALIBI;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s

 struct ggml_tensor * ggml_conv_1d_1s(
@@ -9300,6 +9331,162 @@ static void ggml_compute_forward_soft_max(
     }
 }

+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(float));
+    assert(ne1+n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                pdst[0] = (j+1) * m_k + src[0];
+            }
+        }
+    }
+}
+
+
+static void ggml_compute_forward_alibi_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+    assert(ne1+n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                // we return F32
+                pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_alibi(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_rope

 static void ggml_compute_forward_rope_f32(
@@ -10938,6 +11125,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ALIBI:
+            {
+                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -11140,6 +11331,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ALIBI:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -11673,6 +11868,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             {
                 node->n_tasks = n_threads;
             } break;
+        case GGML_OP_ALIBI:
+            {
+                node->n_tasks = 1; //TODO
+            } break;
         case GGML_OP_CONV_1D_1S:
         case GGML_OP_CONV_1D_2S:
             {
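
For readers skimming the diff: the ALiBi bias is computed entirely inside the new forward kernels above. As a standalone illustration (not part of this commit; the helper name is my own), the per-head slope m_k boils down to the function below. The first n_heads_log2_floor heads take successive powers of m0 = 2^(-8/n_heads_log2_floor), and any remaining heads (when n_head is not a power of two) take odd powers of m1 = 2^(-4/n_heads_log2_floor).

#include <math.h>

// Sketch only: ALiBi slope for head index k, restating the m0/m1 schedule
// used in ggml_compute_forward_alibi_f32/_f16 above.
static float alibi_slope(int k, int n_head) {
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));

    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);

    return k < n_heads_log2_floor
        ? powf(m0, k + 1)                           // heads 0 .. n_heads_log2_floor-1
        : powf(m1, 2*(k - n_heads_log2_floor) + 1); // remaining heads
}

The kernels then add (j + 1) * m_k to every element of row j in head k, on top of the existing KQ_scaled values, writing the result in place.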

ggml.h

Lines changed: 9 additions & 0 deletions
@@ -269,6 +269,7 @@ extern "C" {
        GGML_OP_DIAG_MASK_INF,
        GGML_OP_SOFT_MAX,
        GGML_OP_ROPE,
+       GGML_OP_ALIBI,
        GGML_OP_CONV_1D_1S,
        GGML_OP_CONV_1D_2S,

@@ -662,6 +663,14 @@ extern "C" {
            int n_dims,
            int mode);

+    // alibi position embedding
+    // in-place, returns view(a)
+    struct ggml_tensor * ggml_alibi(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int n_past,
+            int n_head);
+
     // padding = 1
     // TODO: we don't support extra parameters for now
     // that's why we are hard-coding the stride, padding, and dilation
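
For context on how the new API is intended to be used: the kernel comment "add alibi to src0 (KQ_scaled)" in ggml.c and the "in-place, returns view(a)" note above suggest the op slots into an attention graph right after the KQ scaling step. A hypothetical call site (not part of this diff; ctx0, KQ_scaled, n_past, and n_head are assumed to exist in the caller):

// Sketch only: bias the scaled attention scores with ALiBi, then mask and softmax.
struct ggml_tensor * KQ_alibi    = ggml_alibi(ctx0, KQ_scaled, n_past, n_head);
struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf(ctx0, KQ_alibi, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

Because ggml_alibi returns a view of its input, KQ_alibi shares data with KQ_scaled, and the bias is applied in place when the graph is evaluated.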
