Commit fb97b9e ("wip")
1 parent 13c6267

1 file changed: +92, -71 lines

ggml-cuda/rope.cu (92 additions, 71 deletions)

@@ -1,7 +1,7 @@
 #include "rope.cuh"
 
 struct rope_corr_dims {
-    float v[4];
+    float v[4]; // TODO: is there any reson for this to be 4 instead of 2?
 };
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,27 +28,67 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
+//// rope == RoPE == rotary positional embedding
+//template<typename T, bool has_ff>
+//static __global__ void rope_norm(
+//    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+//    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//    if (col >= ncols) {
+//        return;
+//    }
+//
+//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//    const int i = row*ncols + col;
+//    const int i2 = row/p_delta_rows;
+//
+//    const float theta_base = pos[i2]*powf(freq_base, -float(col)/ncols);
+//
+//    const float freq_factor = has_ff ? freq_factors[col/2] : 1.0f;
+//
+//    float cos_theta, sin_theta;
+//    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+//
+//    const float x0 = x[i + 0];
+//    const float x1 = x[i + 1];
+//
+//    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+//    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+//}
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    if (ib > 0) {
+        const int i = row*ncols + ib*n_dims + ic;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ncols + ib*n_dims + ic;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -58,11 +97,10 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos, bool has_freq_facs>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
-) {
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -85,13 +123,12 @@ static __global__ void rope_neox(
     const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
 
-    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
+    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -100,78 +137,66 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
+static void rope_norm_cuda(
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     }
 }
 
 template<typename T>
 static void rope_neox_cuda(
     const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    if (pos == nullptr) {
-        if (freq_factors == nullptr) {
-            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        } else {
-            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
-        }
    } else {
-        if (freq_factors == nullptr) {
-            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
-        } else {
-            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        }
    }
 }
 
-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
@@ -231,12 +256,8 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     pos = (const int32_t *) src1_d;
 
-    if (is_neox) {
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }
 
     rope_corr_dims corr_dims;
@@ -259,14 +280,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ASSERT(false);
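
For reference, the refactored kernels fold freq_base into a precomputed theta_scale = powf(freq_base, -2.0f/n_dims) and compute theta_base = pos[i2]*powf(theta_scale, col/2.0f), where the old rope kernel used pos[i2]*powf(freq_base, -float(col)/ncols). Since theta_scale^(col/2) = freq_base^(-col/n_dims), the two forms agree whenever n_dims == ncols. The standalone host-side sketch below illustrates that equivalence; it is not part of the commit, and the constant values in it are arbitrary examples.

// Sketch (not from the commit): check that the factored theta_base used by
// rope_norm/rope_neox matches the direct freq_base form of the old rope
// kernel when n_dims == ncols.
#include <cmath>
#include <cstdio>

int main() {
    const float freq_base = 10000.0f; // example value
    const int   n_dims    = 128;      // assume ncols == n_dims (full rotation)
    const int   pos       = 42;       // example position

    // theta_scale as computed once on the host in rope_norm_cuda/rope_neox_cuda
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    for (int col = 0; col < n_dims; col += 2) {
        const float theta_new = pos*powf(theta_scale, col/2.0f);          // new kernels
        const float theta_old = pos*powf(freq_base, -float(col)/n_dims);  // old rope kernel, ncols == n_dims
        if (fabsf(theta_new - theta_old) > 1e-3f*fabsf(theta_old) + 1e-6f) {
            printf("mismatch at col=%d: %g vs %g\n", col, theta_new, theta_old);
            return 1;
        }
    }
    printf("theta_base forms agree for n_dims == ncols\n");
    return 0;
}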
