@@ -3528,22 +3528,22 @@ static struct ggml_tensor * ggml_rope_impl(
3528
3528
}
3529
3529
3530
3530
bool is_node = false;
3531
- int sections[3 ] = {0, 0, 0};
3531
+ int sections[4 ] = {0, 0, 0, 0};
3532
3532
3533
3533
if (a->grad) {
3534
3534
is_node = true;
3535
3535
}
3536
3536
3537
3537
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3538
3538
3539
- int32_t params[14 ] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3539
+ int32_t params[15 ] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3540
3540
memcpy(params + 5, &freq_base, sizeof(float));
3541
3541
memcpy(params + 6, &freq_scale, sizeof(float));
3542
3542
memcpy(params + 7, &ext_factor, sizeof(float));
3543
3543
memcpy(params + 8, &attn_factor, sizeof(float));
3544
3544
memcpy(params + 9, &beta_fast, sizeof(float));
3545
3545
memcpy(params + 10, &beta_slow, sizeof(float));
3546
- memcpy(params + 11, &sections, sizeof(int) * 3 );
3546
+ memcpy(params + 11, &sections, sizeof(int) * 4 );
3547
3547
ggml_set_op_params(result, params, sizeof(params));
3548
3548
3549
3549
result->op = GGML_OP_ROPE;
@@ -11255,7 +11255,7 @@ static void ggml_rope_cache_init(
11255
11255
}
11256
11256
11257
11257
static void ggml_mrope_cache_init(
11258
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3 ], bool indep_sects,
11258
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4 ], bool indep_sects,
11259
11259
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
11260
11260
float * cache, float sin_sign, float theta_scale) {
11261
11261
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -11423,19 +11423,21 @@ static void ggml_compute_forward_rope_f32(
11423
11423
if (ir++ < ir0) continue;
11424
11424
if (ir > ir1) break;
11425
11425
11426
- if (! is_neox) {
11426
+ if (is_neox || is_mrope ) {
11427
11427
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11428
+ const int64_t ic = i0/2;
11429
+
11428
11430
const float cos_theta = cache[i0 + 0];
11429
11431
const float sin_theta = cache[i0 + 1];
11430
11432
11431
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11432
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11433
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11434
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11433
11435
11434
11436
const float x0 = src[0];
11435
- const float x1 = src[1 ];
11437
+ const float x1 = src[n_dims/2 ];
11436
11438
11437
- dst_data[0] = x0*cos_theta - x1*sin_theta;
11438
- dst_data[1 ] = x0*sin_theta + x1*cos_theta;
11439
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
11440
+ dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
11439
11441
}
11440
11442
} else if (is_vision){
11441
11443
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11455,19 +11457,17 @@ static void ggml_compute_forward_rope_f32(
11455
11457
}
11456
11458
} else {
11457
11459
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11458
- const int64_t ic = i0/2;
11459
-
11460
11460
const float cos_theta = cache[i0 + 0];
11461
11461
const float sin_theta = cache[i0 + 1];
11462
11462
11463
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11464
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11463
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11464
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11465
11465
11466
11466
const float x0 = src[0];
11467
- const float x1 = src[n_dims/2 ];
11467
+ const float x1 = src[1 ];
11468
11468
11469
- dst_data[0] = x0*cos_theta - x1*sin_theta;
11470
- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
11469
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
11470
+ dst_data[1 ] = x0*sin_theta + x1*cos_theta;
11471
11471
}
11472
11472
}
11473
11473
@@ -11607,19 +11607,21 @@ static void ggml_compute_forward_rope_f16(
11607
11607
if (ir++ < ir0) continue;
11608
11608
if (ir > ir1) break;
11609
11609
11610
- if (! is_neox) {
11610
+ if (is_neox || is_mrope ) {
11611
11611
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11612
+ const int64_t ic = i0/2;
11613
+
11612
11614
const float cos_theta = cache[i0 + 0];
11613
11615
const float sin_theta = cache[i0 + 1];
11614
11616
11615
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11616
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11617
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11618
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11617
11619
11618
11620
const float x0 = GGML_FP16_TO_FP32(src[0]);
11619
- const float x1 = GGML_FP16_TO_FP32(src[1 ]);
11621
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2 ]);
11620
11622
11621
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11622
- dst_data[1 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11623
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11624
+ dst_data[n_dims/2 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11623
11625
}
11624
11626
} else if (is_vision){
11625
11627
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11639,19 +11641,17 @@ static void ggml_compute_forward_rope_f16(
11639
11641
}
11640
11642
} else {
11641
11643
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11642
- const int64_t ic = i0/2;
11643
-
11644
11644
const float cos_theta = cache[i0 + 0];
11645
11645
const float sin_theta = cache[i0 + 1];
11646
11646
11647
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11648
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11647
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11648
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11649
11649
11650
11650
const float x0 = GGML_FP16_TO_FP32(src[0]);
11651
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2 ]);
11651
+ const float x1 = GGML_FP16_TO_FP32(src[1 ]);
11652
11652
11653
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11654
- dst_data[n_dims/2 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11653
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11654
+ dst_data[1 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11655
11655
}
11656
11656
}
11657
11657
0 commit comments