Skip to content

Commit eec6b66

Browse files
committed
ggml-cuda : update rope implementation for parallel decoding
1 parent fa0e677 commit eec6b66

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

ggml-cuda.cu

Lines changed: 36 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <stdio.h>
66
#include <atomic>
77
#include <assert.h>
8+
#include <vector>
89

910
#if defined(GGML_USE_HIPBLAS)
1011
#include <hip/hip_runtime.h>
@@ -4355,7 +4356,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
43554356
}
43564357

43574358
// rope == RoPE == rotary positional embedding
4358-
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
4359+
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
43594360
const float p_delta, const int p_delta_rows, const float theta_scale) {
43604361
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
43614362

@@ -4365,8 +4366,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
43654366

43664367
const int row = blockDim.x*blockIdx.x + threadIdx.x;
43674368
const int i = row*ncols + col;
4369+
const int i2 = row/p_delta_rows;
43684370

4369-
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
4371+
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
43704372
const float sin_theta = sinf(theta);
43714373
const float cos_theta = cosf(theta);
43724374

@@ -4377,7 +4379,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
43774379
dst[i + 1] = x0*sin_theta + x1*cos_theta;
43784380
}
43794381

4380-
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
4382+
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0,
43814383
const float p_delta, const int p_delta_rows, const float theta_scale) {
43824384
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
43834385

@@ -4387,8 +4389,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
43874389

43884390
const int row = blockDim.x*blockIdx.x + threadIdx.x;
43894391
const int i = row*ncols + col/2;
4392+
const int i2 = row/p_delta_rows;
43904393

4391-
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
4394+
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
43924395
const float sin_theta = sinf(theta);
43934396
const float cos_theta = cosf(theta);
43944397

@@ -4399,7 +4402,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
43994402
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
44004403
}
44014404

4402-
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
4405+
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0,
44034406
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
44044407
const int col = blockDim.x*blockIdx.x + threadIdx.x;
44054408
const int half_n_dims = ncols/4;
@@ -4410,9 +4413,10 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
44104413

44114414
const int row = blockDim.y*blockIdx.y + threadIdx.y;
44124415
const int i = row*ncols + col;
4416+
const int i2 = row/p_delta_rows;
44134417

44144418
const float col_theta_scale = powf(theta_scale, col);
4415-
const float p = p0 + p_delta*(row/p_delta_rows);
4419+
const float p = p0[i2] + p_delta*i2;
44164420

44174421
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
44184422
const float sin_theta = sinf(theta);
@@ -5361,7 +5365,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
53615365
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
53625366
}
53635367

5364-
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5368+
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
53655369
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
53665370
GGML_ASSERT(ncols % 2 == 0);
53675371
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -5370,7 +5374,7 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
53705374
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
53715375
}
53725376

5373-
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5377+
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
53745378
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
53755379
GGML_ASSERT(ncols % 2 == 0);
53765380
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -5379,7 +5383,7 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
53795383
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
53805384
}
53815385

5382-
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5386+
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
53835387
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
53845388
GGML_ASSERT(ncols % 4 == 0);
53855389
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
@@ -6069,9 +6073,10 @@ inline void ggml_cuda_op_rope(
60696073

60706074
const int64_t ne00 = src0->ne[0];
60716075
const int64_t ne01 = src0->ne[1];
6076+
const int64_t ne2 = dst->ne[2];
60726077
const int64_t nrows = ggml_nrows(src0);
60736078

6074-
const int n_past = ((int32_t *) dst->op_params)[0];
6079+
//const int n_past = ((int32_t *) dst->op_params)[0];
60756080
const int n_dims = ((int32_t *) dst->op_params)[1];
60766081
const int mode = ((int32_t *) dst->op_params)[2];
60776082
const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6082,21 +6087,38 @@ inline void ggml_cuda_op_rope(
60826087
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
60836088

60846089
const float theta_scale = powf(freq_base, -2.0f/n_dims);
6085-
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6090+
//const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6091+
6092+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
6093+
GGML_ASSERT(src1->ne[0] == ne2);
6094+
6095+
std::vector<float> p0s(ne2);
6096+
for (int64_t i = 0; i < ne2; ++i) {
6097+
int n_past = ((int32_t *) src1->data)[i];
6098+
p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6099+
}
6100+
6101+
size_t p0d_as = 0;
6102+
float * p0d;
6103+
6104+
p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
6105+
CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream));
60866106

60876107
const bool is_neox = mode & 2;
60886108
const bool is_glm = mode & 4;
60896109

60906110
// compute
60916111
if (is_glm) {
6092-
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
6112+
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream);
60936113
} else if (is_neox) {
60946114
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
6095-
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
6115+
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
60966116
} else {
6097-
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
6117+
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
60986118
}
60996119

6120+
ggml_cuda_pool_free(p0d, p0d_as);
6121+
61006122
(void) src1;
61016123
(void) dst;
61026124
(void) src1_dd;

0 commit comments

Comments (0)