5
5
#include < stdio.h>
6
6
#include < atomic>
7
7
#include < assert.h>
8
+ #include < vector>
8
9
9
10
#if defined(GGML_USE_HIPBLAS)
10
11
#include < hip/hip_runtime.h>
@@ -4355,7 +4356,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
4355
4356
}
4356
4357
4357
4358
// rope == RoPE == rotary positional embedding
4358
- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float p0,
4359
+ static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float * p0,
4359
4360
const float p_delta, const int p_delta_rows, const float theta_scale) {
4360
4361
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4361
4362
@@ -4365,8 +4366,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
4365
4366
4366
4367
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
4367
4368
const int i = row*ncols + col;
4369
+ const int i2 = row/p_delta_rows;
4368
4370
4369
- const float theta = (p0 + p_delta * (row/p_delta_rows) )*powf (theta_scale, col/2 );
4371
+ const float theta = (p0[i2] + p_delta*i2 )*powf (theta_scale, col/2 );
4370
4372
const float sin_theta = sinf (theta);
4371
4373
const float cos_theta = cosf (theta);
4372
4374
@@ -4377,7 +4379,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
4377
4379
dst[i + 1 ] = x0*sin_theta + x1*cos_theta;
4378
4380
}
4379
4381
4380
- static __global__ void rope_neox_f32 (const float * x, float * dst, const int ncols, const float p0,
4382
+ static __global__ void rope_neox_f32 (const float * x, float * dst, const int ncols, const float * p0,
4381
4383
const float p_delta, const int p_delta_rows, const float theta_scale) {
4382
4384
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4383
4385
@@ -4387,8 +4389,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
4387
4389
4388
4390
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
4389
4391
const int i = row*ncols + col/2 ;
4392
+ const int i2 = row/p_delta_rows;
4390
4393
4391
- const float theta = (p0 + p_delta * (row/p_delta_rows) )*powf (theta_scale, col/2 );
4394
+ const float theta = (p0[i2] + p_delta*i2 )*powf (theta_scale, col/2 );
4392
4395
const float sin_theta = sinf (theta);
4393
4396
const float cos_theta = cosf (theta);
4394
4397
@@ -4399,7 +4402,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
4399
4402
dst[i + ncols/2 ] = x0*sin_theta + x1*cos_theta;
4400
4403
}
4401
4404
4402
- static __global__ void rope_glm_f32 (const float * x, float * dst, const int ncols, const float p0,
4405
+ static __global__ void rope_glm_f32 (const float * x, float * dst, const int ncols, const float * p0,
4403
4406
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
4404
4407
const int col = blockDim .x *blockIdx .x + threadIdx .x ;
4405
4408
const int half_n_dims = ncols/4 ;
@@ -4410,9 +4413,10 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
4410
4413
4411
4414
const int row = blockDim .y *blockIdx .y + threadIdx .y ;
4412
4415
const int i = row*ncols + col;
4416
+ const int i2 = row/p_delta_rows;
4413
4417
4414
4418
const float col_theta_scale = powf (theta_scale, col);
4415
- const float p = p0 + p_delta*(row/p_delta_rows) ;
4419
+ const float p = p0[i2] + p_delta*i2 ;
4416
4420
4417
4421
const float theta = min (p, p_delta*(n_ctx - 2 ))*col_theta_scale;
4418
4422
const float sin_theta = sinf (theta);
@@ -5361,7 +5365,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
5361
5365
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0 , stream>>> (x, dst, scale, k);
5362
5366
}
5363
5367
5364
- static void rope_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p0,
5368
+ static void rope_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float * p0,
5365
5369
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
5366
5370
GGML_ASSERT (ncols % 2 == 0 );
5367
5371
const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
@@ -5370,7 +5374,7 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
5370
5374
rope_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
5371
5375
}
5372
5376
5373
- static void rope_neox_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p0,
5377
+ static void rope_neox_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float * p0,
5374
5378
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
5375
5379
GGML_ASSERT (ncols % 2 == 0 );
5376
5380
const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
@@ -5379,7 +5383,7 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
5379
5383
rope_neox_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
5380
5384
}
5381
5385
5382
- static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p0,
5386
+ static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float * p0,
5383
5387
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
5384
5388
GGML_ASSERT (ncols % 4 == 0 );
5385
5389
const dim3 block_dims (CUDA_ROPE_BLOCK_SIZE/4 , 1 , 1 );
@@ -6069,9 +6073,10 @@ inline void ggml_cuda_op_rope(
6069
6073
6070
6074
const int64_t ne00 = src0->ne [0 ];
6071
6075
const int64_t ne01 = src0->ne [1 ];
6076
+ const int64_t ne2 = dst->ne [2 ];
6072
6077
const int64_t nrows = ggml_nrows (src0);
6073
6078
6074
- const int n_past = ((int32_t *) dst->op_params )[0 ];
6079
+ // const int n_past = ((int32_t *) dst->op_params)[0];
6075
6080
const int n_dims = ((int32_t *) dst->op_params )[1 ];
6076
6081
const int mode = ((int32_t *) dst->op_params )[2 ];
6077
6082
const int n_ctx = ((int32_t *) dst->op_params )[3 ];
@@ -6082,21 +6087,38 @@ inline void ggml_cuda_op_rope(
6082
6087
memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
6083
6088
6084
6089
const float theta_scale = powf (freq_base, -2 .0f /n_dims);
6085
- const float p0 = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
6090
+ // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6091
+
6092
+ GGML_ASSERT (src1->type == GGML_TYPE_I32);
6093
+ GGML_ASSERT (src1->ne [0 ] == ne2);
6094
+
6095
+ std::vector<float > p0s (ne2);
6096
+ for (int64_t i = 0 ; i < ne2; ++i) {
6097
+ int n_past = ((int32_t *) src1->data )[i];
6098
+ p0s[i] = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
6099
+ }
6100
+
6101
+ size_t p0d_as = 0 ;
6102
+ float * p0d;
6103
+
6104
+ p0d = (float *) ggml_cuda_pool_malloc (ne2 * sizeof (float ), &p0d_as);
6105
+ CUDA_CHECK (cudaMemcpyAsync (p0d, p0s.data (), ne2 * sizeof (float ), cudaMemcpyHostToDevice, main_stream));
6086
6106
6087
6107
const bool is_neox = mode & 2 ;
6088
6108
const bool is_glm = mode & 4 ;
6089
6109
6090
6110
// compute
6091
6111
if (is_glm) {
6092
- rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0 , freq_scale, ne01, theta_scale, n_ctx, main_stream);
6112
+ rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0d , freq_scale, ne01, theta_scale, n_ctx, main_stream);
6093
6113
} else if (is_neox) {
6094
6114
GGML_ASSERT (ne00 == n_dims && " ne00 != n_dims is not implemented for CUDA yet" );
6095
- rope_neox_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0 , freq_scale, ne01, theta_scale, main_stream);
6115
+ rope_neox_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0d , freq_scale, ne01, theta_scale, main_stream);
6096
6116
} else {
6097
- rope_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0 , freq_scale, ne01, theta_scale, main_stream);
6117
+ rope_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0d , freq_scale, ne01, theta_scale, main_stream);
6098
6118
}
6099
6119
6120
+ ggml_cuda_pool_free (p0d, p0d_as);
6121
+
6100
6122
(void ) src1;
6101
6123
(void ) dst;
6102
6124
(void ) src1_dd;
0 commit comments