@@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32(
3201
3201
ggml_tensor * dst) {
3202
3202
3203
3203
const ggml_tensor * src0 = dst->src [0 ];
3204
+ const ggml_tensor * src1 = dst->src [1 ];
3205
+ char * src0_d = (char *) src0->data ;
3206
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3207
+ const size_t src0_o = src0->nb [1 ];
3208
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3204
3209
3205
3210
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3206
3211
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3207
3212
3213
+ if (src1) {
3214
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3215
+ GGML_ASSERT (src0->type == src1->type );
3216
+ }
3217
+
3208
3218
const int ith = params->ith ;
3209
3219
const int nth = params->nth ;
3210
3220
3211
- const int nc = src0->ne [0 ] / 2 ;
3221
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3212
3222
const int nr = ggml_nrows (src0);
3213
3223
3214
3224
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32(
3224
3234
const int ir1 = MIN (ir0 + dr, nr);
3225
3235
3226
3236
for (int i1 = ir0; i1 < ir1; i1++) {
3227
- ggml_vec_reglu_f32 (nc,
3228
- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3229
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3230
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3237
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3238
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3239
+
3240
+ if (!src1) {
3241
+ src0_p += swapped ? nc : 0 ;
3242
+ src1_p += swapped ? 0 : nc;
3243
+ }
3244
+
3245
+ ggml_vec_reglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3231
3246
3232
3247
#ifndef NDEBUG
3233
3248
for (int k = 0 ; k < nc; k++) {
@@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16(
3245
3260
ggml_tensor * dst) {
3246
3261
3247
3262
const ggml_tensor * src0 = dst->src [0 ];
3263
+ const ggml_tensor * src1 = dst->src [1 ];
3264
+ char * src0_d = (char *) src0->data ;
3265
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3266
+ const size_t src0_o = src0->nb [1 ];
3267
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3248
3268
3249
3269
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3250
3270
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3251
3271
3272
+ if (src1) {
3273
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3274
+ GGML_ASSERT (src0->type == src1->type );
3275
+ }
3276
+
3252
3277
const int ith = params->ith ;
3253
3278
const int nth = params->nth ;
3254
3279
3255
- const int nc = src0->ne [0 ] / 2 ;
3280
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3256
3281
const int nr = ggml_nrows (src0);
3257
3282
3258
3283
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16(
3268
3293
const int ir1 = MIN (ir0 + dr, nr);
3269
3294
3270
3295
for (int i1 = ir0; i1 < ir1; i1++) {
3271
- ggml_vec_reglu_f16 (nc,
3272
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3273
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3274
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3296
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3297
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3298
+
3299
+ if (!src1) {
3300
+ src0_p += swapped ? nc : 0 ;
3301
+ src1_p += swapped ? 0 : nc;
3302
+ }
3303
+
3304
+ ggml_vec_reglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3275
3305
3276
3306
#ifndef NDEBUG
3277
3307
for (int k = 0 ; k < nc; k++) {
@@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32(
3314
3344
ggml_tensor * dst) {
3315
3345
3316
3346
const ggml_tensor * src0 = dst->src [0 ];
3347
+ const ggml_tensor * src1 = dst->src [1 ];
3348
+ char * src0_d = (char *) src0->data ;
3349
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3350
+ const size_t src0_o = src0->nb [1 ];
3351
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3317
3352
3318
3353
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3319
3354
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3320
3355
3356
+ if (src1) {
3357
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3358
+ GGML_ASSERT (src0->type == src1->type );
3359
+ }
3360
+
3321
3361
const int ith = params->ith ;
3322
3362
const int nth = params->nth ;
3323
3363
3324
- const int nc = src0->ne [0 ] / 2 ;
3364
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3325
3365
const int nr = ggml_nrows (src0);
3326
3366
3327
3367
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32(
3337
3377
const int ir1 = MIN (ir0 + dr, nr);
3338
3378
3339
3379
for (int i1 = ir0; i1 < ir1; i1++) {
3340
- ggml_vec_geglu_f32 (nc,
3341
- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3342
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3343
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3380
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3381
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3382
+
3383
+ if (!src1) {
3384
+ src0_p += swapped ? nc : 0 ;
3385
+ src1_p += swapped ? 0 : nc;
3386
+ }
3387
+
3388
+ ggml_vec_geglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3344
3389
3345
3390
#ifndef NDEBUG
3346
3391
for (int k = 0 ; k < nc; k++) {
@@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16(
3358
3403
ggml_tensor * dst) {
3359
3404
3360
3405
const ggml_tensor * src0 = dst->src [0 ];
3406
+ const ggml_tensor * src1 = dst->src [1 ];
3407
+ char * src0_d = (char *) src0->data ;
3408
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3409
+ const size_t src0_o = src0->nb [1 ];
3410
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3361
3411
3362
3412
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3363
3413
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3364
3414
3415
+ if (src1) {
3416
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3417
+ GGML_ASSERT (src0->type == src1->type );
3418
+ }
3419
+
3365
3420
const int ith = params->ith ;
3366
3421
const int nth = params->nth ;
3367
3422
3368
- const int nc = src0->ne [0 ] / 2 ;
3423
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3369
3424
const int nr = ggml_nrows (src0);
3370
3425
3371
3426
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16(
3381
3436
const int ir1 = MIN (ir0 + dr, nr);
3382
3437
3383
3438
for (int i1 = ir0; i1 < ir1; i1++) {
3384
- ggml_vec_geglu_f16 (nc,
3385
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3386
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3387
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3439
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3440
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3441
+
3442
+ if (!src1) {
3443
+ src0_p += swapped ? nc : 0 ;
3444
+ src1_p += swapped ? 0 : nc;
3445
+ }
3446
+
3447
+ ggml_vec_geglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3388
3448
3389
3449
#ifndef NDEBUG
3390
3450
for (int k = 0 ; k < nc; k++) {
@@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32(
3427
3487
ggml_tensor * dst) {
3428
3488
3429
3489
const ggml_tensor * src0 = dst->src [0 ];
3490
+ const ggml_tensor * src1 = dst->src [1 ];
3491
+ char * src0_d = (char *) src0->data ;
3492
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3493
+ const size_t src0_o = src0->nb [1 ];
3494
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3430
3495
3431
3496
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3432
3497
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3433
3498
3499
+ if (src1) {
3500
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3501
+ GGML_ASSERT (src0->type == src1->type );
3502
+ }
3503
+
3434
3504
const int ith = params->ith ;
3435
3505
const int nth = params->nth ;
3436
3506
3437
- const int nc = src0->ne [0 ] / 2 ;
3507
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3438
3508
const int nr = ggml_nrows (src0);
3439
3509
3440
3510
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32(
3450
3520
const int ir1 = MIN (ir0 + dr, nr);
3451
3521
3452
3522
for (int i1 = ir0; i1 < ir1; i1++) {
3453
- ggml_vec_swiglu_f32 (nc,
3454
- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3455
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3456
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3523
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3524
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3525
+
3526
+ if (!src1) {
3527
+ src0_p += swapped ? nc : 0 ;
3528
+ src1_p += swapped ? 0 : nc;
3529
+ }
3530
+
3531
+ ggml_vec_swiglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3457
3532
3458
3533
#ifndef NDEBUG
3459
3534
for (int k = 0 ; k < nc; k++) {
@@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16(
3471
3546
ggml_tensor * dst) {
3472
3547
3473
3548
const ggml_tensor * src0 = dst->src [0 ];
3549
+ const ggml_tensor * src1 = dst->src [1 ];
3550
+ char * src0_d = (char *) src0->data ;
3551
+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3552
+ const size_t src0_o = src0->nb [1 ];
3553
+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3474
3554
3475
3555
GGML_ASSERT (ggml_is_contiguous_1 (src0));
3476
3556
GGML_ASSERT (ggml_is_contiguous_1 (dst));
3477
3557
3558
+ if (src1) {
3559
+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3560
+ GGML_ASSERT (src0->type == src1->type );
3561
+ }
3562
+
3478
3563
const int ith = params->ith ;
3479
3564
const int nth = params->nth ;
3480
3565
3481
- const int nc = src0->ne [0 ] / 2 ;
3566
+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
3482
3567
const int nr = ggml_nrows (src0);
3483
3568
3484
3569
GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16(
3494
3579
const int ir1 = MIN (ir0 + dr, nr);
3495
3580
3496
3581
for (int i1 = ir0; i1 < ir1; i1++) {
3497
- ggml_vec_swiglu_f16 (nc,
3498
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3499
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3500
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3582
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3583
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3584
+
3585
+ if (!src1) {
3586
+ src0_p += swapped ? nc : 0 ;
3587
+ src1_p += swapped ? 0 : nc;
3588
+ }
3589
+
3590
+ ggml_vec_swiglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3501
3591
3502
3592
#ifndef NDEBUG
3503
3593
for (int k = 0 ; k < nc; k++) {
0 commit comments