@@ -1,7 +1,7 @@
 #include "rope.cuh"

 struct rope_corr_dims {
-    float v[4];
+    float v[4]; // TODO: is there any reason for this to be 4 instead of 2?
 };

 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
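
A note on the TODO above: nothing in this diff reads past the first two entries of `rope_corr_dims`; they feed `rope_yarn_ramp` as the low/high cutoff dimensions of the YaRN ramp. A sketch of the ramp body, assuming the upstream implementation that the hunk elides:

```cpp
// Assumed upstream body (sketch): only corr_dims.v[0] and corr_dims.v[1]
// ever reach this function, as `low` and `high`.
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y)); // 1 inside the corrected range, 0 outside
}
```
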
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
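
The hunk stops at the entry of `rope_yarn`; the elided middle is where YaRN blends the interpolated and extrapolated angles through the ramp. A sketch of that elided section, assuming the upstream body:

```cpp
// Sketch of the code between the context lines above and below (assumed
// from the upstream file): mix theta_interp and theta_extrap via the ramp,
// then correct the magnitude for interpolation.
if (ext_factor != 0.0f) {
    float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
    theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
```
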
@@ -29,27 +28,67 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }

-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
+// // rope == RoPE == rotary positional embedding
+// template<typename T, bool has_ff>
+// static __global__ void rope_norm(
+//     const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+//     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+//     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//     if (col >= ncols) {
+//         return;
+//     }
+//
+//     const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//     const int i = row*ncols + col;
+//     const int i2 = row/p_delta_rows;
+//
+//     const float theta_base = pos[i2]*powf(freq_base, -float(col)/ncols);
+//
+//     const float freq_factor = has_ff ? freq_factors[col/2] : 1.0f;
+//
+//     float cos_theta, sin_theta;
+//     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+//
+//     const float x0 = x[i + 0];
+//     const float x1 = x[i + 1];
+//
+//     dst[i + 0] = x0*cos_theta - x1*sin_theta;
+//     dst[i + 1] = x0*sin_theta + x1*cos_theta;
+// }
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
         return;
     }

     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    if (ib > 0) {
+        const int i = row*ncols + ib*n_dims + ic;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
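
The new `ib`/`ic` split gives `rope_norm` the same partial-rotation support `rope_neox` already had: only the first `n_dims` columns of each row are rotated, and the `ib > 0` branch above copies the tail through unchanged. A quick host-side check with hypothetical sizes:

```cpp
// Which column pairs does the ib > 0 branch pass through? Hypothetical sizes.
#include <cstdio>

int main() {
    const int ncols = 8, n_dims = 4;
    for (int col = 0; col < ncols; col += 2) {
        const int ib = col / n_dims;
        printf("col=%d -> %s\n", col, ib > 0 ? "copied through" : "rotated");
    }
    // prints: col=0,2 rotated; col=4,6 copied through
}
```
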
+
+    const int i = row*ncols + ib*n_dims + ic;
     const int i2 = row/p_delta_rows;

-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
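
The removed `powf(freq_base, -float(col)/ncols)` and the new `powf(theta_scale, col/2.0f)` compute the same angle base, since with `theta_scale = powf(freq_base, -2.0f/n_dims)` we get `theta_scale^(col/2) = freq_base^(-col/n_dims)` (and `n_dims == ncols` in the old full-rotation case). Precomputing `theta_scale` on the host is what lets the kernel drop its `freq_base` parameter. A small self-contained check with hypothetical values:

```cpp
// Sanity check of the theta_base refactor (hypothetical freq_base and n_dims).
#include <cmath>
#include <cstdio>

int main() {
    const float freq_base = 10000.0f;
    const int   n_dims    = 128; // == ncols in the full-rotation case
    const float theta_scale = powf(freq_base, -2.0f/n_dims); // host-side, once
    for (int col = 0; col < n_dims; col += 2) {
        const float old_way = powf(freq_base, -float(col)/n_dims);
        const float new_way = powf(theta_scale, col/2.0f);
        printf("col=%3d old=%.6f new=%.6f\n", col, old_way, new_way);
    }
}
```
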
@@ -58,11 +97,10 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-template<typename T, bool has_pos, bool has_freq_facs>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
-) {
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -85,13 +123,12 @@ static __global__ void rope_neox(
     const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;

-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);

-    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
+    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
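
With both kernels now in view: after this change they share the same `theta_base`/`freq_factor`/`rope_yarn` sequence and differ mainly in which elements form a rotation pair, as the `x1` loads show. `rope_norm` rotates adjacent elements, while `rope_neox` pairs an element with the one `n_dims/2` ahead. A standalone illustration, not part of the patch:

```cpp
// Illustration of the two pairing conventions (hypothetical helper).
#include <cstdio>

static void rope_pair(int n_dims, bool is_neox, int k, int * i0, int * i1) {
    if (is_neox) {
        *i0 = k;            // first half of the rotated block
        *i1 = k + n_dims/2; // second half
    } else {
        *i0 = 2*k + 0;      // even element
        *i1 = 2*k + 1;      // its odd neighbour
    }
}

int main() {
    int a, b;
    for (int k = 0; k < 4; ++k) {
        rope_pair(8, false, k, &a, &b); printf("norm: (%d,%d)   ", a, b);
        rope_pair(8, true,  k, &a, &b); printf("neox: (%d,%d)\n", a, b);
    }
}
```
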
@@ -100,78 +137,66 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }

-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
+static void rope_norm_cuda(
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     }
 }
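
The launch geometry is untouched by the rename: each thread still rotates one `(x0, x1)` pair, so the grid's y dimension covers `ncols/2` pairs per row. Worked through for hypothetical sizes, assuming `CUDA_ROPE_BLOCK_SIZE` is 256 as in the upstream common header:

```cpp
// Host-side sketch of the grid computed above (hypothetical tensor sizes).
#include <cstdio>

int main() {
    const int CUDA_ROPE_BLOCK_SIZE = 256; // assumed upstream value
    const int ncols = 4096, nrows = 32;
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    // block = (1, 256, 1): 256 threads per block, one (x0, x1) pair each
    printf("grid = (%d, %d, 1), block = (1, %d, 1)\n", nrows, num_blocks_x, CUDA_ROPE_BLOCK_SIZE);
}
```
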

 template<typename T>
 static void rope_neox_cuda(
     const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);

     const float theta_scale = powf(freq_base, -2.0f/n_dims);

-    if (pos == nullptr) {
-        if (freq_factors == nullptr) {
-            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        } else {
-            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
-        }
     } else {
-        if (freq_factors == nullptr) {
-            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
             x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
             theta_scale, freq_factors
         );
-        } else {
-            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        }
     }
 }
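
Since `pos` is now unconditionally dereferenced in the kernels, the old `pos == nullptr` compile-time branch disappears and the dispatch drops from four instantiations (`has_pos` × `has_freq_facs`) to two (`has_ff`). A minimal self-contained sketch of the pattern, with hypothetical names:

```cpp
// Two-way compile-time dispatch on an optional pointer (hypothetical kernel).
#include <cuda_runtime.h>

template<bool has_ff>
static __global__ void demo_kernel(const float * freq_factors, float * dst, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    const float ff = has_ff ? freq_factors[i] : 1.0f; // resolved at compile time
    dst[i] = ff;
}

static void demo_launch(const float * freq_factors, float * dst, int n, cudaStream_t stream) {
    const int nblocks = (n + 255)/256;
    if (freq_factors == nullptr) {
        demo_kernel<false><<<nblocks, 256, 0, stream>>>(freq_factors, dst, n);
    } else {
        demo_kernel<true ><<<nblocks, 256, 0, stream>>>(freq_factors, dst, n);
    }
}
```
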

-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }

 static void rope_neox_cuda_f16(
@@ -231,12 +256,8 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     pos = (const int32_t *) src1_d;

-    if (is_neox) {
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }

     rope_corr_dims corr_dims;
@@ -259,14 +280,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);