
Commit 4739018

cuda : better rope implementation
ggml-ci
1 parent fb97b9e

2 files changed: +64 -88 lines

ggml-cuda/rope.cu

Lines changed: 63 additions & 88 deletions
@@ -28,67 +28,38 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-//// rope == RoPE == rotary positional embedding
-//template<typename T, bool has_ff>
-//static __global__ void rope_norm(
-//    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-//    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-//
-//    if (col >= ncols) {
-//        return;
-//    }
-//
-//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-//    const int i = row*ncols + col;
-//    const int i2 = row/p_delta_rows;
-//
-//    const float theta_base = pos[i2]*powf(freq_base, -float(col)/ncols);
-//
-//    const float freq_factor = has_ff ? freq_factors[col/2] : 1.0f;
-//
-//    float cos_theta, sin_theta;
-//    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-//
-//    const float x0 = x[i + 0];
-//    const float x1 = x[i + 1];
-//
-//    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-//    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-//}
-
 template<typename T, bool has_ff>
 static __global__ void rope_norm(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
 
         return;
     }
 
-    const int i = row*ncols + ib*n_dims + ic;
+    const int i = row*ne0 + i0;
     const int i2 = row/p_delta_rows;
 
-    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
+    float cos_theta;
+    float sin_theta;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
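
Note: the rewrite above collapses the old (ib, ic) block/remainder indexing into a single column index i0. Columns at or beyond n_dims are copied through unchanged, and "norm" mode rotates the adjacent pair (i + 0, i + 1). A minimal host-side sketch of the same indexing, assuming plain RoPE angles in place of the YaRN-corrected rope_yarn() and no freq_factors:

```cpp
#include <cmath>
#include <cstdint>

// Host-side sketch of the rope_norm indexing after this change (illustrative
// only). Simplification: plain RoPE angles, i.e. no YaRN ext_factor/attn_factor
// correction and no freq_factors scaling.
static void rope_norm_ref(const float * x, float * dst, int ne0, int n_dims, int nr,
                          const int32_t * pos, int p_delta_rows, float freq_base) {
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    for (int row = 0; row < nr; ++row) {
        const int i2 = row/p_delta_rows; // rows of the same batch entry share a position

        for (int i0 = 0; i0 < ne0; i0 += 2) {
            const int i = row*ne0 + i0;

            if (i0 >= n_dims) { // past the rotated prefix: pass-through
                dst[i + 0] = x[i + 0];
                dst[i + 1] = x[i + 1];
                continue;
            }

            const float theta     = pos[i2]*powf(theta_scale, i0/2.0f);
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);

            dst[i + 0] = x[i + 0]*cos_theta - x[i + 1]*sin_theta; // rotate adjacent pair
            dst[i + 1] = x[i + 0]*sin_theta + x[i + 1]*cos_theta;
        }
    }
}
```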
@@ -99,36 +70,36 @@ static __global__ void rope_norm(
 
 template<typename T, bool has_ff>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
 
         return;
     }
 
-    const int i = row*ncols + ib*n_dims + ic/2;
+    const int i = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;
 
-    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
-    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
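
Note: relative to rope_norm, the only indexing change here is which two elements form a rotation pair: i = row*ne0 + i0/2 is read at offsets i + 0 and i + n_dims/2, so the pair is split across the two halves of the rotated prefix. A hypothetical helper (not part of this commit) that spells out the two layouts:

```cpp
// Which two element offsets (within a row) does a given even i0 rotate together?
// Illustrative helper, not part of the commit.
struct rope_pair { int a, b; };

static rope_pair rope_pair_for(int i0, int n_dims, bool is_neox) {
    if (is_neox) {
        return { i0/2, i0/2 + n_dims/2 }; // partner sits in the other half of the prefix
    }
    return { i0, i0 + 1 };                // "norm": adjacent elements
}
```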
@@ -139,79 +110,79 @@ static __global__ void rope_neox(
 
 template<typename T>
 static void rope_norm_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
         rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
     } else {
         rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
     }
 }
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
     }
 }
 
 static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_norm_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_norm_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
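
Note: in the launchers above each thread rotates one pair, i.e. two elements along dim 0, so the grid needs ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)) blocks in y while x covers the nr rows, one thread each. A sketch of the arithmetic, assuming CUDA_ROPE_BLOCK_SIZE is 256 (its value in ggml-cuda at the time of writing):

```cpp
// Rounded-up block count along dim 0; each thread covers 2 elements.
// Assumption (not restated in this diff): CUDA_ROPE_BLOCK_SIZE == 256.
static int rope_n_blocks_x(int ne0, int block_size) {
    return (ne0 + 2*block_size - 1) / (2*block_size);
}

// e.g. ne0 = 4096: (4096 + 511) / 512 = 8 blocks in grid.y,
// so the launch is block_nums = (nr, 8, 1) with block_dims = (1, 256, 1).
```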
@@ -232,30 +203,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims       = ((int32_t *) dst->op_params)[1];
+    const int mode         = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx   = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-
     const bool is_neox = mode & 2;
 
-    pos = (const int32_t *) src1_d;
+    const int32_t * pos = (const int32_t *) src1_d;
 
+    const float * freq_factors = nullptr;
     if (src2 != nullptr) {
         freq_factors = (const float *) src2->data;
     }
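
Note: the float hyperparameters live in slots 5..10 of the int32_t op_params array, so they are extracted with memcpy, which is the well-defined way to reinterpret the stored bit patterns (a plain pointer cast would be type punning). A minimal sketch of the same unpacking, with a hypothetical helper name:

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper (not in the commit): read a float stored bit-for-bit
// in one int32_t slot of op_params. memcpy avoids undefined type punning.
static float read_f32_param(const int32_t * op_params, int slot) {
    float v;
    memcpy(&v, op_params + slot, sizeof(float));
    return v;
}

// e.g. float freq_base = read_f32_param(op_params, 5);
```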
@@ -267,12 +242,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {
@@ -281,12 +256,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     } else {
         if (src0->type == GGML_TYPE_F32) {
             rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -2252,6 +2252,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                             }
                         }
+
                         all = false;
                     }
                 }
