Skip to content

Commit 99bb260

Browse files
committed
metal : implement RoPE (mode = 2) + avoid ggml_repeat
1 parent e3c52bd commit 99bb260

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

ggml-metal.metal

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,25 @@ kernel void kernel_rope(
571571
dst_data[1] = x0*sin_theta + x1*cos_theta;
572572
}
573573
} else {
574-
// TODO: implement
574+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
575+
for (int64_t ic = 0; ic < n_dims; ic += 2) {
576+
const float cos_theta = cos(theta);
577+
const float sin_theta = sin(theta);
578+
579+
theta *= theta_scale;
580+
581+
const int64_t i0 = ib*n_dims + ic/2;
582+
583+
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
584+
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
585+
586+
const float x0 = src[0];
587+
const float x1 = src[n_dims/2];
588+
589+
dst_data[0] = x0*cos_theta - x1*sin_theta;
590+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
591+
}
592+
}
575593
}
576594
}
577595

llama.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,19 +2445,15 @@ static struct ggml_cgraph * llm_build_falcon(
24452445
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
24462446

24472447
attn_norm = ggml_add(ctx0,
2448-
ggml_mul(ctx0,
2449-
ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm),
2450-
attn_norm),
2451-
ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm));
2448+
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
2449+
model.layers[il].attn_norm_b);
24522450

24532451
if (model.layers[il].attn_norm_2) { // Falcon-40B
24542452
cur = ggml_norm(ctx0, inpL, norm_eps);
24552453

24562454
cur = ggml_add(ctx0,
2457-
ggml_mul(ctx0,
2458-
ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur),
2459-
cur),
2460-
ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur));
2455+
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
2456+
model.layers[il].attn_norm_2_b);
24612457
} else { // Falcon 7B
24622458
cur = attn_norm;
24632459
}
@@ -2595,10 +2591,8 @@ static struct ggml_cgraph * llm_build_falcon(
25952591
cur = ggml_norm(ctx0, inpL, norm_eps);
25962592

25972593
cur = ggml_add(ctx0,
2598-
ggml_mul(ctx0,
2599-
ggml_repeat(ctx0, model.output_norm, cur),
2600-
cur),
2601-
ggml_repeat(ctx0, model.output_norm_b, cur));
2594+
ggml_mul(ctx0, cur, model.output_norm),
2595+
model.output_norm_b);
26022596
ggml_set_name(cur, "result_norm");
26032597
}
26042598

0 commit comments

Comments
 (0)