Skip to content

Commit 241bb45

Browse files
committed
fix rope op mode switching, out dated func args
1 parent f1fa60f commit 241bb45

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

ggml/src/ggml.c

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,22 +3516,22 @@ static struct ggml_tensor * ggml_rope_impl(
35163516
}
35173517

35183518
bool is_node = false;
3519-
int sections[3] = {0, 0, 0};
3519+
int sections[4] = {0, 0, 0, 0};
35203520

35213521
if (a->grad) {
35223522
is_node = true;
35233523
}
35243524

35253525
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
35263526

3527-
int32_t params[14] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3527+
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
35283528
memcpy(params + 5, &freq_base, sizeof(float));
35293529
memcpy(params + 6, &freq_scale, sizeof(float));
35303530
memcpy(params + 7, &ext_factor, sizeof(float));
35313531
memcpy(params + 8, &attn_factor, sizeof(float));
35323532
memcpy(params + 9, &beta_fast, sizeof(float));
35333533
memcpy(params + 10, &beta_slow, sizeof(float));
3534-
memcpy(params + 11, &sections, sizeof(int) * 3);
3534+
memcpy(params + 11, &sections, sizeof(int) * 4);
35353535
ggml_set_op_params(result, params, sizeof(params));
35363536

35373537
result->op = GGML_OP_ROPE;
@@ -11238,7 +11238,7 @@ static void ggml_rope_cache_init(
1123811238
}
1123911239

1124011240
static void ggml_mrope_cache_init(
11241-
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects,
11241+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
1124211242
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
1124311243
float * cache, float sin_sign, float theta_scale) {
1124411244
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -11406,19 +11406,21 @@ static void ggml_compute_forward_rope_f32(
1140611406
if (ir++ < ir0) continue;
1140711407
if (ir > ir1) break;
1140811408

11409-
if (!is_neox) {
11409+
if (is_neox || is_mrope) {
1141011410
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11411+
const int64_t ic = i0/2;
11412+
1141111413
const float cos_theta = cache[i0 + 0];
1141211414
const float sin_theta = cache[i0 + 1];
1141311415

11414-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11415-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11416+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11417+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1141611418

1141711419
const float x0 = src[0];
11418-
const float x1 = src[1];
11420+
const float x1 = src[n_dims/2];
1141911421

11420-
dst_data[0] = x0*cos_theta - x1*sin_theta;
11421-
dst_data[1] = x0*sin_theta + x1*cos_theta;
11422+
dst_data[0] = x0*cos_theta - x1*sin_theta;
11423+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1142211424
}
1142311425
} else if (is_vision){
1142411426
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11438,19 +11440,17 @@ static void ggml_compute_forward_rope_f32(
1143811440
}
1143911441
} else {
1144011442
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11441-
const int64_t ic = i0/2;
11442-
1144311443
const float cos_theta = cache[i0 + 0];
1144411444
const float sin_theta = cache[i0 + 1];
1144511445

11446-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11447-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
11446+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11447+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1144811448

1144911449
const float x0 = src[0];
11450-
const float x1 = src[n_dims/2];
11450+
const float x1 = src[1];
1145111451

11452-
dst_data[0] = x0*cos_theta - x1*sin_theta;
11453-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
11452+
dst_data[0] = x0*cos_theta - x1*sin_theta;
11453+
dst_data[1] = x0*sin_theta + x1*cos_theta;
1145411454
}
1145511455
}
1145611456

@@ -11590,19 +11590,21 @@ static void ggml_compute_forward_rope_f16(
1159011590
if (ir++ < ir0) continue;
1159111591
if (ir > ir1) break;
1159211592

11593-
if (!is_neox) {
11593+
if (is_neox || is_mrope) {
1159411594
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11595+
const int64_t ic = i0/2;
11596+
1159511597
const float cos_theta = cache[i0 + 0];
1159611598
const float sin_theta = cache[i0 + 1];
1159711599

11598-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11599-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11600+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11601+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1160011602

1160111603
const float x0 = GGML_FP16_TO_FP32(src[0]);
11602-
const float x1 = GGML_FP16_TO_FP32(src[1]);
11604+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
1160311605

11604-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11605-
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11606+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11607+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1160611608
}
1160711609
} else if (is_vision){
1160811610
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11622,19 +11624,17 @@ static void ggml_compute_forward_rope_f16(
1162211624
}
1162311625
} else {
1162411626
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11625-
const int64_t ic = i0/2;
11626-
1162711627
const float cos_theta = cache[i0 + 0];
1162811628
const float sin_theta = cache[i0 + 1];
1162911629

11630-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11631-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
11630+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11631+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1163211632

1163311633
const float x0 = GGML_FP16_TO_FP32(src[0]);
11634-
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
11634+
const float x1 = GGML_FP16_TO_FP32(src[1]);
1163511635

11636-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11637-
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11636+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11637+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1163811638
}
1163911639
}
1164011640

0 commit comments

Comments
 (0)