1
1
#include " rope.cuh"
2
2
3
3
struct rope_corr_dims {
4
- float v[4 ];
4
+ float v[4 ]; // TODO: is there any reson for this to be 4 instead of 2?
5
5
};
6
6
7
7
static __device__ float rope_yarn_ramp (const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
13
13
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14
14
static __device__ void rope_yarn (
15
15
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
16
- float * cos_theta, float * sin_theta
17
- ) {
16
+ float * cos_theta, float * sin_theta) {
18
17
// Get n-d rotational scaling corrected for extrapolation
19
18
float theta_interp = freq_scale * theta_extrap;
20
19
float theta = theta_interp;
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
29
28
*sin_theta = sinf (theta) * mscale;
30
29
}
31
30
32
- // rope == RoPE == rotary positional embedding
33
- template <typename T, bool has_pos>
34
- static __global__ void rope (
35
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
36
- float ext_factor, float attn_factor, rope_corr_dims corr_dims
37
- ) {
38
- const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
31
+ template <typename T, bool has_ff>
32
+ static __global__ void rope_norm (
33
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
34
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
35
+ const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
39
36
40
- if (col >= ncols ) {
37
+ if (i0 >= ne0 ) {
41
38
return ;
42
39
}
43
40
44
41
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
45
- const int i = row*ncols + col;
42
+
43
+ if (i0 >= n_dims) {
44
+ const int i = row*ne0 + i0;
45
+
46
+ dst[i + 0 ] = x[i + 0 ];
47
+ dst[i + 1 ] = x[i + 1 ];
48
+
49
+ return ;
50
+ }
51
+
52
+ const int i = row*ne0 + i0;
46
53
const int i2 = row/p_delta_rows;
47
54
48
- const int p = has_pos ? pos[i2] : 0 ;
49
- const float theta_base = p*powf (freq_base, -float (col)/ncols);
55
+ const float theta_base = pos[i2]*powf (theta_scale, i0/2 .0f );
56
+
57
+ const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
58
+
59
+ float cos_theta;
60
+ float sin_theta;
50
61
51
- float cos_theta, sin_theta;
52
- rope_yarn (theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
62
+ rope_yarn (theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
53
63
54
64
const float x0 = x[i + 0 ];
55
65
const float x1 = x[i + 1 ];
@@ -58,40 +68,38 @@ static __global__ void rope(
58
68
dst[i + 1 ] = x0*sin_theta + x1*cos_theta;
59
69
}
60
70
61
- template <typename T, bool has_pos, bool has_freq_facs >
71
+ template <typename T, bool has_ff >
62
72
static __global__ void rope_neox (
63
- const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
65
- ) {
66
- const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
73
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
74
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
75
+ const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
67
76
68
- if (col >= ncols ) {
77
+ if (i0 >= ne0 ) {
69
78
return ;
70
79
}
71
80
72
81
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
73
- const int ib = col / n_dims;
74
- const int ic = col % n_dims;
75
82
76
- if (ib > 0 ) {
77
- const int i = row*ncols + ib*n_dims + ic ;
83
+ if (i0 >= n_dims ) {
84
+ const int i = row*ne0 + i0 ;
78
85
79
86
dst[i + 0 ] = x[i + 0 ];
80
87
dst[i + 1 ] = x[i + 1 ];
81
88
82
89
return ;
83
90
}
84
91
85
- const int i = row*ncols + ib*n_dims + ic /2 ;
92
+ const int i = row*ne0 + i0 /2 ;
86
93
const int i2 = row/p_delta_rows;
87
94
88
- const int p = has_pos ? pos[i2] : 0 ;
89
- const float freq_factor = has_freq_facs ? freq_factors[ic/2 ] : 1 .0f ;
95
+ const float theta_base = pos[i2]*powf (theta_scale, i0/2 .0f );
96
+
97
+ const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
90
98
91
- const float theta_base = p*powf (theta_scale, col/2 .0f )/freq_factor;
99
+ float cos_theta;
100
+ float sin_theta;
92
101
93
- float cos_theta, sin_theta;
94
- rope_yarn (theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
102
+ rope_yarn (theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
95
103
96
104
const float x0 = x[i + 0 ];
97
105
const float x1 = x[i + n_dims/2 ];
@@ -100,93 +108,81 @@ static __global__ void rope_neox(
100
108
dst[i + n_dims/2 ] = x0*sin_theta + x1*cos_theta;
101
109
}
102
110
103
-
104
111
template <typename T>
105
- static void rope_cuda (
106
- const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
107
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
108
- ) {
109
- GGML_ASSERT (ncols % 2 == 0 );
112
+ static void rope_norm_cuda (
113
+ const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
114
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
115
+ GGML_ASSERT (ne0 % 2 == 0 );
110
116
const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
111
- const int num_blocks_x = (ncols + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
112
- const dim3 block_nums (nrows, num_blocks_x, 1 );
113
- if (pos == nullptr ) {
114
- rope<T, false ><<<block_nums, block_dims, 0 , stream>>> (
115
- x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
116
- );
117
+ const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
118
+ const dim3 block_nums (nr, n_blocks_x, 1 );
119
+
120
+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
121
+
122
+ if (freq_factors == nullptr ) {
123
+ rope_norm<T, false ><<<block_nums, block_dims, 0 , stream>>> (
124
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
125
+ theta_scale, freq_factors
126
+ );
117
127
} else {
118
- rope<T, true ><<<block_nums, block_dims, 0 , stream>>> (
119
- x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
120
- );
128
+ rope_norm<T, true ><<<block_nums, block_dims, 0 , stream>>> (
129
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
130
+ theta_scale, freq_factors
131
+ );
121
132
}
122
133
}
123
134
124
135
template <typename T>
125
136
static void rope_neox_cuda (
126
- const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
127
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
128
- ) {
129
- GGML_ASSERT (ncols % 2 == 0 );
137
+ const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
138
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
139
+ GGML_ASSERT (ne0 % 2 == 0 );
130
140
const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
131
- const int num_blocks_x = (ncols + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
132
- const dim3 block_nums (nrows, num_blocks_x , 1 );
141
+ const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
142
+ const dim3 block_nums (nr, n_blocks_x , 1 );
133
143
134
144
const float theta_scale = powf (freq_base, -2 .0f /n_dims);
135
145
136
- if (pos == nullptr ) {
137
- if (freq_factors == nullptr ) {
138
- rope_neox<T, false , false ><<<block_nums, block_dims, 0 , stream>>> (
139
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
140
- theta_scale, freq_factors
141
- );
142
- } else {
143
- rope_neox<T, false , true ><<<block_nums, block_dims, 0 , stream>>> (
144
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
146
+ if (freq_factors == nullptr ) {
147
+ rope_neox<T, false ><<<block_nums, block_dims, 0 , stream>>> (
148
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
145
149
theta_scale, freq_factors
146
150
);
147
- }
148
151
} else {
149
- if (freq_factors == nullptr ) {
150
- rope_neox<T, true , false ><<<block_nums, block_dims, 0 , stream>>> (
151
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
152
- theta_scale, freq_factors
153
- );
154
- } else {
155
- rope_neox<T, true , true ><<<block_nums, block_dims, 0 , stream>>> (
156
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
152
+ rope_neox<T, true ><<<block_nums, block_dims, 0 , stream>>> (
153
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
157
154
theta_scale, freq_factors
158
155
);
159
- }
160
156
}
161
157
}
162
158
163
- static void rope_cuda_f16 (
164
- const half * x, half * dst, int ncols , int nrows , const int32_t * pos, float freq_scale, int p_delta_rows,
165
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
159
+ static void rope_norm_cuda_f16 (
160
+ const half * x, half * dst, int ne0 , int n_dims, int nr , const int32_t * pos, float freq_scale, int p_delta_rows,
161
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
166
162
167
- rope_cuda <half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
163
+ rope_norm_cuda <half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors , stream);
168
164
}
169
165
170
- static void rope_cuda_f32 (
171
- const float * x, float * dst, int ncols , int nrows , const int32_t * pos, float freq_scale, int p_delta_rows,
172
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
166
+ static void rope_norm_cuda_f32 (
167
+ const float * x, float * dst, int ne0 , int n_dims, int nr , const int32_t * pos, float freq_scale, int p_delta_rows,
168
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
173
169
174
- rope_cuda <float >(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
170
+ rope_norm_cuda <float >(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors , stream);
175
171
}
176
172
177
173
static void rope_neox_cuda_f16 (
178
- const half * x, half * dst, int ncols , int n_dims, int nrows , const int32_t * pos, float freq_scale, int p_delta_rows,
174
+ const half * x, half * dst, int ne0 , int n_dims, int nr , const int32_t * pos, float freq_scale, int p_delta_rows,
179
175
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
180
176
181
- rope_neox_cuda<half>(x, dst, ncols , n_dims, nrows , pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
177
+ rope_neox_cuda<half>(x, dst, ne0 , n_dims, nr , pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
182
178
}
183
179
184
180
static void rope_neox_cuda_f32 (
185
- const float * x, float * dst, int ncols , int n_dims, int nrows , const int32_t * pos, float freq_scale, int p_delta_rows,
181
+ const float * x, float * dst, int ne0 , int n_dims, int nr , const int32_t * pos, float freq_scale, int p_delta_rows,
186
182
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
187
183
) {
188
184
189
- rope_neox_cuda<float >(x, dst, ncols , n_dims, nrows , pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
185
+ rope_neox_cuda<float >(x, dst, ne0 , n_dims, nr , pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
190
186
}
191
187
192
188
void ggml_cuda_op_rope (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -207,36 +203,36 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
207
203
208
204
const int64_t ne00 = src0->ne [0 ];
209
205
const int64_t ne01 = src0->ne [1 ];
210
- const int64_t nrows = ggml_nrows (src0);
206
+ const int64_t nr = ggml_nrows (src0);
211
207
212
- // const int n_past = ((int32_t *) dst->op_params)[0];
213
- const int n_dims = ((int32_t *) dst->op_params )[1 ];
214
- const int mode = ((int32_t *) dst->op_params )[2 ];
215
- // const int n_ctx = ((int32_t *) dst->op_params)[3];
216
- const int n_orig_ctx = ((int32_t *) dst->op_params )[4 ];
208
+ // const int n_past = ((int32_t *) dst->op_params)[0];
209
+ const int n_dims = ((int32_t *) dst->op_params )[1 ];
210
+ const int mode = ((int32_t *) dst->op_params )[2 ];
211
+ // const int n_ctx = ((int32_t *) dst->op_params)[3];
212
+ const int n_orig_ctx = ((int32_t *) dst->op_params )[4 ];
217
213
218
214
// RoPE alteration for extended context
219
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
215
+ float freq_base;
216
+ float freq_scale;
217
+ float ext_factor;
218
+ float attn_factor;
219
+ float beta_fast;
220
+ float beta_slow;
221
+
220
222
memcpy (&freq_base, (int32_t *) dst->op_params + 5 , sizeof (float ));
221
223
memcpy (&freq_scale, (int32_t *) dst->op_params + 6 , sizeof (float ));
222
224
memcpy (&ext_factor, (int32_t *) dst->op_params + 7 , sizeof (float ));
223
225
memcpy (&attn_factor, (int32_t *) dst->op_params + 8 , sizeof (float ));
224
226
memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
225
227
memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
226
228
227
- const float * freq_factors = nullptr ;
228
- const int32_t * pos = nullptr ;
229
-
230
229
const bool is_neox = mode & 2 ;
231
230
232
- pos = (const int32_t *) src1_d;
231
+ const int32_t * pos = (const int32_t *) src1_d;
233
232
234
- if (is_neox) {
235
- if (src2 != nullptr ) {
236
- freq_factors = (const float *) src2->data ;
237
- }
238
- } else {
239
- GGML_ASSERT (src2 == nullptr && " TODO: freq_factors not implemented for !is_neox" );
233
+ const float * freq_factors = nullptr ;
234
+ if (src2 != nullptr ) {
235
+ freq_factors = (const float *) src2->data ;
240
236
}
241
237
242
238
rope_corr_dims corr_dims;
@@ -246,27 +242,27 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
246
242
if (is_neox) {
247
243
if (src0->type == GGML_TYPE_F32) {
248
244
rope_neox_cuda_f32 (
249
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
245
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
250
246
attn_factor, corr_dims, freq_factors, stream
251
247
);
252
248
} else if (src0->type == GGML_TYPE_F16) {
253
249
rope_neox_cuda_f16 (
254
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
250
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
255
251
attn_factor, corr_dims, freq_factors, stream
256
252
);
257
253
} else {
258
254
GGML_ASSERT (false );
259
255
}
260
256
} else {
261
257
if (src0->type == GGML_TYPE_F32) {
262
- rope_cuda_f32 (
263
- (const float *)src0_d, (float *)dst_d, ne00, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
264
- attn_factor, corr_dims, stream
258
+ rope_norm_cuda_f32 (
259
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
260
+ attn_factor, corr_dims, freq_factors, stream
265
261
);
266
262
} else if (src0->type == GGML_TYPE_F16) {
267
- rope_cuda_f16 (
268
- (const half *)src0_d, (half *)dst_d, ne00, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
269
- attn_factor, corr_dims, stream
263
+ rope_norm_cuda_f16 (
264
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
265
+ attn_factor, corr_dims, freq_factors, stream
270
266
);
271
267
} else {
272
268
GGML_ASSERT (false );
0 commit comments