@@ -28,67 +28,38 @@ static __device__ void rope_yarn(
    *sin_theta = sinf(theta) * mscale;
}

-// // rope == RoPE == rotary positional embedding
-// template<typename T, bool has_ff>
-// static __global__ void rope_norm(
-//     const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-//     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-//     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-//
-//     if (col >= ncols) {
-//         return;
-//     }
-//
-//     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-//     const int i = row*ncols + col;
-//     const int i2 = row/p_delta_rows;
-//
-//     const float theta_base = pos[i2]*powf(freq_base, -float(col)/ncols);
-//
-//     const float freq_factor = has_ff ? freq_factors[col/2] : 1.0f;
-//
-//     float cos_theta, sin_theta;
-//     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-//
-//     const float x0 = x[i + 0];
-//     const float x1 = x[i + 1];
-//
-//     dst[i + 0] = x0*cos_theta - x1*sin_theta;
-//     dst[i + 1] = x0*sin_theta + x1*cos_theta;
-// }
-
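// rope_norm rotates adjacent element pairs (x[i + 0], x[i + 1]); elements with
// i0 >= n_dims lie outside the rotated sub-space and are copied through unchanged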
template<typename T, bool has_ff>
static __global__ void rope_norm(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

-    if (col >= ncols) {
+    if (i0 >= ne0) {
        return;
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;

-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;

        dst[i + 0] = x[i + 0];
        dst[i + 1] = x[i + 1];

        return;
    }

-    const int i = row*ncols + ib*n_dims + ic;
+    const int i = row*ne0 + i0;
    const int i2 = row/p_delta_rows;

-    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

-    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
+    float cos_theta;
+    float sin_theta;

-    float cos_theta, sin_theta;
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];
@@ -99,36 +70,36 @@ static __global__ void rope_norm(

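// rope_neox instead pairs element i0/2 with element i0/2 + n_dims/2, rotating
// across the two halves of the rotated sub-space rather than adjacent elements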
template<typename T, bool has_ff>
static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

-    if (col >= ncols) {
+    if (i0 >= ne0) {
        return;
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;

-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;

        dst[i + 0] = x[i + 0];
        dst[i + 1] = x[i + 1];

        return;
    }

-    const int i = row*ncols + ib*n_dims + ic/2;
+    const int i = row*ne0 + i0/2;
    const int i2 = row/p_delta_rows;

-    const float theta_base = pos[i2]*powf(theta_scale, col/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

-    const float freq_factor = has_ff ? freq_factors[ic/2] : 1.0f;
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

-    float cos_theta, sin_theta;
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + n_dims/2];
@@ -139,79 +110,79 @@ static __global__ void rope_neox(

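// host-side launcher: with block_dims = (1, CUDA_ROPE_BLOCK_SIZE, 1), each tensor
// row maps to one blockIdx.x and each element pair to one thread along y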
template<typename T>
static void rope_norm_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);

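    // theta_scale^(i0/2) == freq_base^(-i0/n_dims), the usual RoPE frequency for pair i0/2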
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
    } else {
        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
    }
}

template<typename T>
static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
    } else {
        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, freq_factors
        );
    }
}
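// fixed-type wrappers dispatched from ggml_cuda_op_rope below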
static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_norm_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_norm_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {

-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -232,30 +203,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);

-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_past      = ((int32_t *) dst->op_params)[0];
+    const int n_dims      = ((int32_t *) dst->op_params)[1];
+    const int mode        = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];

    // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
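    // the extension parameters are packed as raw float bits in op_params[5..10]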
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-
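    // bit 1 of mode (value 2) selects the NeoX-style rotation layout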
    const bool is_neox = mode & 2;

-    pos = (const int32_t *) src1_d;
+    const int32_t * pos = (const int32_t *) src1_d;

+    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }
@@ -267,12 +242,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
@@ -281,12 +256,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    } else {
        if (src0->type == GGML_TYPE_F32) {
            rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {