Skip to content

Commit 06b5c62

Browse files
committed
cuda : restore lost changes (StableLM rope)
1 parent c6b3d19 commit 06b5c62

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

ggml-cuda.cu

Lines changed: 22 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4774,8 +4774,8 @@ static __global__ void rope(
47744774

47754775
template<typename T, bool has_pos>
47764776
static __global__ void rope_neox(
4777-
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
4778-
float ext_factor, float attn_factor, rope_corr_dims corr_dims
4777+
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
4778+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
47794779
) {
47804780
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
47814781

@@ -4784,23 +4784,25 @@ static __global__ void rope_neox(
47844784
}
47854785

47864786
const int row = blockDim.x*blockIdx.x + threadIdx.x;
4787-
const int i = row*ncols + col/2;
4787+
const int ib = col / n_dims;
4788+
const int ic = col % n_dims;
4789+
4790+
const int i = row*ncols + ib*n_dims + ic/2;
47884791
const int i2 = row/p_delta_rows;
47894792

4790-
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4791-
const float cur_rot = -float(col)/ncols;
4793+
float cur_rot = inv_ndims * ic - ib;
47924794

47934795
const int p = has_pos ? pos[i2] : 0;
4794-
const float theta_base = p*powf(freq_base, cur_rot);
4796+
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
47954797

47964798
float cos_theta, sin_theta;
47974799
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
47984800

47994801
const float x0 = x[i + 0];
4800-
const float x1 = x[i + ncols/2];
4802+
const float x1 = x[i + n_dims/2];
48014803

4802-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4803-
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
4804+
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4805+
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
48044806
}
48054807

48064808
static __global__ void rope_glm_f32(
@@ -6085,20 +6087,26 @@ static void rope_cuda(
60856087

60866088
template<typename T>
60876089
static void rope_neox_cuda(
6088-
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
6090+
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
60896091
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
60906092
) {
60916093
GGML_ASSERT(ncols % 2 == 0);
60926094
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
60936095
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
60946096
const dim3 block_nums(nrows, num_blocks_x, 1);
6097+
6098+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
6099+
const float inv_ndims = -1.0f / n_dims;
6100+
60956101
if (pos == nullptr) {
60966102
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
6097-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
6103+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
6104+
theta_scale, inv_ndims
60986105
);
60996106
} else {
61006107
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
6101-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
6108+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
6109+
theta_scale, inv_ndims
61026110
);
61036111
}
61046112
}
@@ -7039,12 +7047,12 @@ inline void ggml_cuda_op_rope(
70397047
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
70407048
if (src0->type == GGML_TYPE_F32) {
70417049
rope_neox_cuda(
7042-
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
7050+
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
70437051
attn_factor, corr_dims, main_stream
70447052
);
70457053
} else if (src0->type == GGML_TYPE_F16) {
70467054
rope_neox_cuda(
7047-
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
7055+
(const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
70487056
attn_factor, corr_dims, main_stream
70497057
);
70507058
} else {

0 commit comments

Comments
 (0)