@@ -3516,22 +3516,22 @@ static struct ggml_tensor * ggml_rope_impl(
3516
3516
}
3517
3517
3518
3518
bool is_node = false;
3519
- int sections[3 ] = {0, 0, 0};
3519
+ int sections[4 ] = {0, 0, 0, 0};
3520
3520
3521
3521
if (a->grad) {
3522
3522
is_node = true;
3523
3523
}
3524
3524
3525
3525
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3526
3526
3527
- int32_t params[14 ] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3527
+ int32_t params[15 ] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3528
3528
memcpy(params + 5, &freq_base, sizeof(float));
3529
3529
memcpy(params + 6, &freq_scale, sizeof(float));
3530
3530
memcpy(params + 7, &ext_factor, sizeof(float));
3531
3531
memcpy(params + 8, &attn_factor, sizeof(float));
3532
3532
memcpy(params + 9, &beta_fast, sizeof(float));
3533
3533
memcpy(params + 10, &beta_slow, sizeof(float));
3534
- memcpy(params + 11, &sections, sizeof(int) * 3 );
3534
+ memcpy(params + 11, &sections, sizeof(int) * 4 );
3535
3535
ggml_set_op_params(result, params, sizeof(params));
3536
3536
3537
3537
result->op = GGML_OP_ROPE;
@@ -11238,7 +11238,7 @@ static void ggml_rope_cache_init(
11238
11238
}
11239
11239
11240
11240
static void ggml_mrope_cache_init(
11241
- float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3 ], bool indep_sects,
11241
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4 ], bool indep_sects,
11242
11242
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
11243
11243
float * cache, float sin_sign, float theta_scale) {
11244
11244
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -11406,19 +11406,21 @@ static void ggml_compute_forward_rope_f32(
11406
11406
if (ir++ < ir0) continue;
11407
11407
if (ir > ir1) break;
11408
11408
11409
- if (! is_neox) {
11409
+ if (is_neox || is_mrope ) {
11410
11410
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11411
+ const int64_t ic = i0/2;
11412
+
11411
11413
const float cos_theta = cache[i0 + 0];
11412
11414
const float sin_theta = cache[i0 + 1];
11413
11415
11414
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11415
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11416
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11417
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11416
11418
11417
11419
const float x0 = src[0];
11418
- const float x1 = src[1 ];
11420
+ const float x1 = src[n_dims/2 ];
11419
11421
11420
- dst_data[0] = x0*cos_theta - x1*sin_theta;
11421
- dst_data[1 ] = x0*sin_theta + x1*cos_theta;
11422
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
11423
+ dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
11422
11424
}
11423
11425
} else if (is_vision){
11424
11426
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11438,19 +11440,17 @@ static void ggml_compute_forward_rope_f32(
11438
11440
}
11439
11441
} else {
11440
11442
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11441
- const int64_t ic = i0/2;
11442
-
11443
11443
const float cos_theta = cache[i0 + 0];
11444
11444
const float sin_theta = cache[i0 + 1];
11445
11445
11446
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11447
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11446
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11447
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11448
11448
11449
11449
const float x0 = src[0];
11450
- const float x1 = src[n_dims/2 ];
11450
+ const float x1 = src[1 ];
11451
11451
11452
- dst_data[0] = x0*cos_theta - x1*sin_theta;
11453
- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
11452
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
11453
+ dst_data[1 ] = x0*sin_theta + x1*cos_theta;
11454
11454
}
11455
11455
}
11456
11456
@@ -11590,19 +11590,21 @@ static void ggml_compute_forward_rope_f16(
11590
11590
if (ir++ < ir0) continue;
11591
11591
if (ir > ir1) break;
11592
11592
11593
- if (! is_neox) {
11593
+ if (is_neox || is_mrope ) {
11594
11594
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11595
+ const int64_t ic = i0/2;
11596
+
11595
11597
const float cos_theta = cache[i0 + 0];
11596
11598
const float sin_theta = cache[i0 + 1];
11597
11599
11598
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11599
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11600
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11601
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11600
11602
11601
11603
const float x0 = GGML_FP16_TO_FP32(src[0]);
11602
- const float x1 = GGML_FP16_TO_FP32(src[1 ]);
11604
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2 ]);
11603
11605
11604
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11605
- dst_data[1 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11606
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11607
+ dst_data[n_dims/2 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11606
11608
}
11607
11609
} else if (is_vision){
11608
11610
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11622,19 +11624,17 @@ static void ggml_compute_forward_rope_f16(
11622
11624
}
11623
11625
} else {
11624
11626
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11625
- const int64_t ic = i0/2;
11626
-
11627
11627
const float cos_theta = cache[i0 + 0];
11628
11628
const float sin_theta = cache[i0 + 1];
11629
11629
11630
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic *nb00);
11631
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic *nb0);
11630
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0 *nb00);
11631
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0 *nb0);
11632
11632
11633
11633
const float x0 = GGML_FP16_TO_FP32(src[0]);
11634
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2 ]);
11634
+ const float x1 = GGML_FP16_TO_FP32(src[1 ]);
11635
11635
11636
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11637
- dst_data[n_dims/2 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11636
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11637
+ dst_data[1 ] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11638
11638
}
11639
11639
}
11640
11640
0 commit comments