
Commit 35dacd1

CISC, 0cc4m, and qnixsynapse
committed
ggml : implement GLU for split up/gate (#14181)
* implement GLU for split up/gate
* add tests for ggml_glu_split
* Vulkan: Implement glu_split logic and shader support
* add split to logging [no ci]
* SYCL: refactor element_size ops and add split up and gate support to gated kernels
* SYCL: switch GEGLU to use tanh approximation

---------

Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>
1 parent a9aedf4 commit 35dacd1
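
The last commit-message item switches the SYCL GEGLU kernel to the tanh approximation of GELU. For reference, a minimal sketch of that approximation in plain C — this is the standard formula, shown here for context only, not the SYCL kernel code from this commit:

    #include <math.h>

    // tanh-based GELU approximation:
    //   0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
    static float gelu_tanh_approx(float x) {
        const float c = 0.79788456080286535588f; // sqrt(2/pi)
        return 0.5f * x * (1.0f + tanhf(c * (x + 0.044715f * x * x * x)));
    }

In a GEGLU gate, this value multiplies the linear (up) branch element-wise.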

File tree — 14 files changed: +919 −1387 lines changed


ggml/include/ggml.h

Lines changed: 23 additions & 0 deletions
@@ -1132,6 +1132,29 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // A: n columns, r rows,
+    // B: n columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum ggml_glu_op      op);
+
+    GGML_API struct ggml_tensor * ggml_reglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
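
A hedged usage sketch of the new split API (not part of this commit's diff): ggml_glu_split takes an explicit op, while the reglu/geglu/swiglu wrappers fix it. The context size, tensor shapes, the GGML_GLU_OP_SWIGLU enum value, the "ggml-cpu.h" include, and the assumption that a is the activated branch and b the linear branch are illustrative; check the backend kernels and headers of your ggml version for the authoritative details.

    #include "ggml.h"
    #include "ggml-cpu.h"   // for ggml_graph_compute_with_ctx in recent ggml layouts

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,   // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_ff = 1024, n_tokens = 8;  // illustrative shapes
        struct ggml_tensor * gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);
        struct ggml_tensor * up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);

        // generic entry point with an explicit op ...
        struct ggml_tensor * swiglu = ggml_glu_split(ctx, gate, up, GGML_GLU_OP_SWIGLU);
        // ... or the per-op convenience wrappers
        struct ggml_tensor * geglu  = ggml_geglu_split(ctx, gate, up);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, swiglu);
        ggml_build_forward_expand(gf, geglu);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);

        ggml_free(ctx);
        return 0;
    }

Unlike the fused variants (e.g. ggml_swiglu), which split a single input of width 2*n down the middle, the *_split variants take the up and gate projections as two separate tensors of width n, which is how many models store them on disk.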

ggml/src/ggml-cpu/ops.cpp

Lines changed: 120 additions & 30 deletions
@@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_reglu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_reglu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_geglu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_geglu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_swiglu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
     GGML_ASSERT(ggml_is_contiguous_1(dst));
 
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0] / 2;
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
     const int nr = ggml_nrows(src0);
 
     GGML_ASSERT(dst->ne[0] == nc);
@@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16(
    const int ir1 = MIN(ir0 + dr, nr);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_swiglu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
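
To make the pointer arithmetic above concrete, here is a standalone sketch in plain C (not ggml code, and using an illustrative vec_reglu_f32 helper): the fused path derives both operands from one row of width 2*nc, with swapped selecting which half is the activated branch, while the split path simply reads one full row of width nc from each input.

    #include <stdio.h>

    #define NC 4

    // reglu: dst[i] = max(x[i], 0) * g[i]
    static void vec_reglu_f32(int n, float * dst, const float * x, const float * g) {
        for (int i = 0; i < n; i++) {
            dst[i] = (x[i] > 0.0f ? x[i] : 0.0f) * g[i];
        }
    }

    int main(void) {
        // fused layout: one row of width 2*NC, activated half first, gate half second
        float fused[2*NC] = { -1, 2, -3, 4,   0.5f, 0.5f, 0.5f, 0.5f };
        // split layout: the same data as two rows of width NC
        float x[NC] = { -1, 2, -3, 4 };
        float g[NC] = { 0.5f, 0.5f, 0.5f, 0.5f };

        float dst_fused[NC], dst_split[NC];
        int swapped = 0;

        // fused: both pointers come from the same row, one offset by nc
        const float * x_p = fused + (swapped ? NC : 0);
        const float * g_p = fused + (swapped ? 0 : NC);
        vec_reglu_f32(NC, dst_fused, x_p, g_p);

        // split: each pointer is the start of its own row
        vec_reglu_f32(NC, dst_split, x, g);

        for (int i = 0; i < NC; i++) {
            printf("%g %g\n", dst_fused[i], dst_split[i]);  // identical columns
        }
        return 0;
    }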
