Skip to content

Commit d93fa48

Browse files
committed
fix rope op mode switching, out dated func args
1 parent 2a03ea7 commit d93fa48

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

ggml/src/ggml.c

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3528,22 +3528,22 @@ static struct ggml_tensor * ggml_rope_impl(
35283528
}
35293529

35303530
bool is_node = false;
3531-
int sections[3] = {0, 0, 0};
3531+
int sections[4] = {0, 0, 0, 0};
35323532

35333533
if (a->grad) {
35343534
is_node = true;
35353535
}
35363536

35373537
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
35383538

3539-
int32_t params[14] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3539+
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
35403540
memcpy(params + 5, &freq_base, sizeof(float));
35413541
memcpy(params + 6, &freq_scale, sizeof(float));
35423542
memcpy(params + 7, &ext_factor, sizeof(float));
35433543
memcpy(params + 8, &attn_factor, sizeof(float));
35443544
memcpy(params + 9, &beta_fast, sizeof(float));
35453545
memcpy(params + 10, &beta_slow, sizeof(float));
3546-
memcpy(params + 11, &sections, sizeof(int) * 3);
3546+
memcpy(params + 11, &sections, sizeof(int) * 4);
35473547
ggml_set_op_params(result, params, sizeof(params));
35483548

35493549
result->op = GGML_OP_ROPE;
@@ -11255,7 +11255,7 @@ static void ggml_rope_cache_init(
1125511255
}
1125611256

1125711257
static void ggml_mrope_cache_init(
11258-
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects,
11258+
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
1125911259
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
1126011260
float * cache, float sin_sign, float theta_scale) {
1126111261
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -11423,19 +11423,21 @@ static void ggml_compute_forward_rope_f32(
1142311423
if (ir++ < ir0) continue;
1142411424
if (ir > ir1) break;
1142511425

11426-
if (!is_neox) {
11426+
if (is_neox || is_mrope) {
1142711427
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11428+
const int64_t ic = i0/2;
11429+
1142811430
const float cos_theta = cache[i0 + 0];
1142911431
const float sin_theta = cache[i0 + 1];
1143011432

11431-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11432-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11433+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11434+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1143311435

1143411436
const float x0 = src[0];
11435-
const float x1 = src[1];
11437+
const float x1 = src[n_dims/2];
1143611438

11437-
dst_data[0] = x0*cos_theta - x1*sin_theta;
11438-
dst_data[1] = x0*sin_theta + x1*cos_theta;
11439+
dst_data[0] = x0*cos_theta - x1*sin_theta;
11440+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1143911441
}
1144011442
} else if (is_vision){
1144111443
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11455,19 +11457,17 @@ static void ggml_compute_forward_rope_f32(
1145511457
}
1145611458
} else {
1145711459
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11458-
const int64_t ic = i0/2;
11459-
1146011460
const float cos_theta = cache[i0 + 0];
1146111461
const float sin_theta = cache[i0 + 1];
1146211462

11463-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11464-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
11463+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11464+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1146511465

1146611466
const float x0 = src[0];
11467-
const float x1 = src[n_dims/2];
11467+
const float x1 = src[1];
1146811468

11469-
dst_data[0] = x0*cos_theta - x1*sin_theta;
11470-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
11469+
dst_data[0] = x0*cos_theta - x1*sin_theta;
11470+
dst_data[1] = x0*sin_theta + x1*cos_theta;
1147111471
}
1147211472
}
1147311473

@@ -11607,19 +11607,21 @@ static void ggml_compute_forward_rope_f16(
1160711607
if (ir++ < ir0) continue;
1160811608
if (ir > ir1) break;
1160911609

11610-
if (!is_neox) {
11610+
if (is_neox || is_mrope) {
1161111611
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11612+
const int64_t ic = i0/2;
11613+
1161211614
const float cos_theta = cache[i0 + 0];
1161311615
const float sin_theta = cache[i0 + 1];
1161411616

11615-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11616-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11617+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11618+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1161711619

1161811620
const float x0 = GGML_FP16_TO_FP32(src[0]);
11619-
const float x1 = GGML_FP16_TO_FP32(src[1]);
11621+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
1162011622

11621-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11622-
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11623+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11624+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1162311625
}
1162411626
} else if (is_vision){
1162511627
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -11639,19 +11641,17 @@ static void ggml_compute_forward_rope_f16(
1163911641
}
1164011642
} else {
1164111643
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
11642-
const int64_t ic = i0/2;
11643-
1164411644
const float cos_theta = cache[i0 + 0];
1164511645
const float sin_theta = cache[i0 + 1];
1164611646

11647-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
11648-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
11647+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11648+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1164911649

1165011650
const float x0 = GGML_FP16_TO_FP32(src[0]);
11651-
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
11651+
const float x1 = GGML_FP16_TO_FP32(src[1]);
1165211652

11653-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11654-
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
11653+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
11654+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1165511655
}
1165611656
}
1165711657

0 commit comments

Comments
 (0)