Skip to content

Commit c5d70f5

Browse files
authored
ggml : optimize rope function to avoid call powf in the tight loop (#807)
1 parent be87b6e commit c5d70f5

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

ggml.c

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7507,19 +7507,20 @@ static void ggml_compute_forward_rope_f32(
75077507
// row index used to determine which thread to use
75087508
int ir = 0;
75097509

7510+
const float theta_scale = powf(10000.0, ((float)-2)/n_dims);
7511+
75107512
for (int64_t i3 = 0; i3 < ne3; i3++) {
75117513
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
75127514
const int p = (mode == 0 ? n_past + i2 : i2);
75137515
for (int64_t i1 = 0; i1 < ne1; i1++) {
75147516
if (ir++ < ir0) continue;
75157517
if (ir > ir1) break;
7516-
7518+
float theta = (float)p;
75177519
for (int i0 = 0; i0 < n_dims; i0 += 2) {
7518-
const float theta = powf(10000.0, ((float)-i0)/n_dims);
7519-
7520-
const float cos_theta = cosf(p*theta);
7521-
const float sin_theta = sinf(p*theta);
7520+
const float cos_theta = cosf(theta);
7521+
const float sin_theta = sinf(theta);
75227522

7523+
theta *= theta_scale;
75237524
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
75247525
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
75257526

@@ -7580,19 +7581,20 @@ static void ggml_compute_forward_rope_f16(
75807581
// row index used to determine which thread to use
75817582
int ir = 0;
75827583

7584+
const float theta_scale = powf(10000.0, ((float)-2)/n_dims);
7585+
75837586
for (int64_t i3 = 0; i3 < ne3; i3++) {
75847587
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
75857588
const int p = (mode == 0 ? n_past + i2 : i2);
75867589
for (int64_t i1 = 0; i1 < ne1; i1++) {
75877590
if (ir++ < ir0) continue;
75887591
if (ir > ir1) break;
7589-
7592+
float theta = (float)p;
75907593
for (int i0 = 0; i0 < n_dims; i0 += 2) {
7591-
const float theta = powf(10000.0, ((float)-i0)/n_dims);
7592-
7593-
const float cos_theta = cosf(p*theta);
7594-
const float sin_theta = sinf(p*theta);
7594+
const float cos_theta = cosf(theta);
7595+
const float sin_theta = sinf(theta);
75957596

7597+
theta *= theta_scale;
75967598
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
75977599
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
75987600

0 commit comments

Comments
 (0)