Skip to content

Commit dcc5520

Browse files
ggml: cache sin/cos for RoPE
1 parent de473f5 commit dcc5520

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

ggml.c

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11720,6 +11720,24 @@ static void ggml_compute_forward_rope_f32(
1172011720
for (int64_t i3 = 0; i3 < ne3; i3++) {
1172111721
for (int64_t i2 = 0; i2 < ne2; i2++) {
1172211722
const int64_t p = pos[i2];
11723+
11724+
float * cache = ((float *) (params->wdata)) + ith*ne0;
11725+
float theta_base_cache = (float) p;
11726+
if (!is_glm && !is_neox) {
11727+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11728+
float cos_theta, sin_theta;
11729+
rope_yarn(
11730+
theta_base_cache, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11731+
);
11732+
sin_theta *= sin_sign;
11733+
11734+
cache[i0 + 0] = cos_theta;
11735+
cache[i0 + 1] = sin_theta;
11736+
11737+
theta_base_cache *= theta_scale;
11738+
}
11739+
}
11740+
1172311741
for (int64_t i1 = 0; i1 < ne1; i1++) {
1172411742
if (ir++ < ir0) continue;
1172511743
if (ir > ir1) break;
@@ -11753,11 +11771,8 @@ static void ggml_compute_forward_rope_f32(
1175311771
}
1175411772
} else if (!is_neox) {
1175511773
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11756-
float cos_theta, sin_theta;
11757-
rope_yarn(
11758-
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11759-
);
11760-
sin_theta *= sin_sign;
11774+
const float cos_theta = cache[i0 + 0];
11775+
const float sin_theta = cache[i0 + 1];
1176111776

1176211777
// zeta scaling for xPos only:
1176311778
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@@ -16722,6 +16737,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
1672216737
}
1672316738
} break;
1672416739
case GGML_OP_SOFT_MAX:
16740+
case GGML_OP_ROPE:
1672516741
{
1672616742
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1672716743
} break;

0 commit comments

Comments
 (0)