Skip to content

Commit b822605

Browse files
committed
ggml : fixes (hopefully)
ggml-ci
1 parent 9d5605f commit b822605

File tree

4 files changed

+28
-51
lines changed

4 files changed

+28
-51
lines changed

ggml-cuda/rope.cu

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static __global__ void rope(
6161
template<typename T, bool has_pos, bool has_freq_facs>
6262
static __global__ void rope_neox(
6363
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
64+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
6565
) {
6666
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
6767

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
8585
const int i = row*ncols + ib*n_dims + ic/2;
8686
const int i2 = row/p_delta_rows;
8787

88-
float cur_rot = inv_ndims * ic - ib;
89-
9088
const int p = has_pos ? pos[i2] : 0;
9189
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
9290

93-
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
91+
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
9492

9593
float cos_theta, sin_theta;
96-
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
94+
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
9795

9896
const float x0 = x[i + 0];
9997
const float x1 = x[i + n_dims/2];
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
174172
const dim3 block_nums(nrows, num_blocks_x, 1);
175173

176174
const float theta_scale = powf(freq_base, -2.0f/n_dims);
177-
const float inv_ndims = -1.0f / n_dims;
178175

179176
if (pos == nullptr) {
180177
if (freq_factors == nullptr) {
181178
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
182179
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
183-
theta_scale, inv_ndims, freq_factors
180+
theta_scale, freq_factors
184181
);
185182
} else {
186183
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
187184
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
188-
theta_scale, inv_ndims, freq_factors
185+
theta_scale, freq_factors
189186
);
190187
}
191188
} else {
192189
if (freq_factors == nullptr) {
193190
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
194191
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
195-
theta_scale, inv_ndims, freq_factors
192+
theta_scale, freq_factors
196193
);
197194
} else {
198195
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
199196
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
200-
theta_scale, inv_ndims, freq_factors
197+
theta_scale, freq_factors
201198
);
202199
}
203200
}

ggml-metal.metal

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(
17671767

17681768
const int64_t p = pos[i2];
17691769

1770-
const float theta_0 = (float)p;
1770+
const float theta_base = (float)p;
17711771
const float inv_ndims = -1.f/n_dims;
17721772

17731773
if (!is_neox) {
17741774
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
17751776

1776-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
17771777
float cos_theta, sin_theta;
17781778
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
17791779

@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
17891789
} else {
17901790
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
17911791
if (ic < n_dims) {
1792-
const int64_t ib = 0;
1792+
const int64_t i0 = ic/2;
17931793

1794-
// simplified from `(ib * n_dims + ic) * inv_ndims`
1795-
const float cur_rot = inv_ndims*ic - ib;
1796-
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1794+
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
17971795

1798-
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1796+
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
17991797

18001798
float cos_theta, sin_theta;
1801-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1802-
1803-
const int64_t i0 = ib*n_dims + ic/2;
1799+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
18041800

18051801
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
18061802
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

ggml.c

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14358,7 +14358,7 @@ static void ggml_compute_forward_rope_f32(
1435814358
int ir = 0;
1435914359

1436014360
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14361-
const float inv_ndims = -1.f/n_dims;
14361+
1436214362
float corr_dims[2];
1436314363
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1436414364

@@ -14407,7 +14407,7 @@ static void ggml_compute_forward_rope_f32(
1440714407
const float cos_block_theta = cosf(block_theta);
1440814408
const float sin_block_theta = sinf(block_theta) * sin_sign;
1440914409

14410-
theta_base *= theta_scale;
14410+
theta_base *= theta_scale;
1441114411
block_theta *= theta_scale;
1441214412

1441314413
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14442,22 @@ static void ggml_compute_forward_rope_f32(
1444214442
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
1444314443
}
1444414444
} else {
14445-
// TODO: this might be wrong for ne0 != n_dims - need double check
14446-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14447-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14448-
theta_base *= freq_scale;
14445+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1444914446
for (int64_t ic = 0; ic < ne0; ic += 2) {
1445014447
if (ic < n_dims) {
14451-
const int64_t ib = 0;
14448+
const int64_t i0 = ic/2;
1445214449

14453-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14454-
float cur_rot = inv_ndims * ic - ib;
14455-
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14450+
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
1445614451

1445714452
float cos_theta, sin_theta;
1445814453
rope_yarn(
14459-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14454+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1446014455
&cos_theta, &sin_theta
1446114456
);
14462-
sin_theta *= sin_sign;
1446314457

14458+
sin_theta *= sin_sign;
1446414459
theta_base *= theta_scale;
1446514460

14466-
const int64_t i0 = ib*n_dims + ic/2;
14467-
1446814461
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1446914462
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1447014463

@@ -14543,7 +14536,7 @@ static void ggml_compute_forward_rope_f16(
1454314536
int ir = 0;
1454414537

1454514538
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14546-
const float inv_ndims = -1.f/n_dims;
14539+
1454714540
float corr_dims[2];
1454814541
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1454914542

@@ -14592,7 +14585,7 @@ static void ggml_compute_forward_rope_f16(
1459214585
const float cos_block_theta = cosf(block_theta);
1459314586
const float sin_block_theta = sinf(block_theta) * sin_sign;
1459414587

14595-
theta_base *= theta_scale;
14588+
theta_base *= theta_scale;
1459614589
block_theta *= theta_scale;
1459714590

1459814591
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14623,29 +14616,22 @@ static void ggml_compute_forward_rope_f16(
1462314616
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1462414617
}
1462514618
} else {
14626-
// TODO: this might be wrong for ne0 != n_dims - need double check
14627-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14628-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14629-
theta_base *= freq_scale;
14619+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1463014620
for (int64_t ic = 0; ic < ne0; ic += 2) {
1463114621
if (ic < n_dims) {
14632-
const int64_t ib = 0;
14622+
const int64_t i0 = ic/2;
1463314623

14634-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14635-
float cur_rot = inv_ndims * ic - ib;
14636-
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14624+
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
1463714625

1463814626
float cos_theta, sin_theta;
1463914627
rope_yarn(
14640-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14628+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1464114629
&cos_theta, &sin_theta
1464214630
);
14643-
sin_theta *= sin_sign;
1464414631

14632+
sin_theta *= sin_sign;
1464514633
theta_base *= theta_scale;
1464614634

14647-
const int64_t i0 = ib*n_dims + ic/2;
14648-
1464914635
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1465014636
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1465114637

ggml_vk_generate_shaders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,14 +2670,12 @@
26702670
const uint i = row*p.ncols + ib*p.ndims + ic/2;
26712671
const uint i2 = row/p.p_delta_rows;
26722672
2673-
const float cur_rot = p.inv_ndims * ic - ib;
2674-
26752673
const int pos = data_b[i2];
26762674
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
26772675
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
26782676
26792677
float cos_theta, sin_theta;
2680-
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
2678+
rope_yarn(theta_base, ic, cos_theta, sin_theta);
26812679
26822680
const float x0 = float(data_a[i + 0]);
26832681
const float x1 = float(data_a[i + p.ndims/2]);

0 commit comments

Comments
 (0)