Commit ec03bbc

ggml : remove GLM rope mode
ggml-ci
1 parent d29218c commit ec03bbc

13 files changed, +141 -353 lines

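Caller-visible effect first: ggml_rope and ggml_rope_ext each lose the n_ctx argument that only the GLM mode consumed, and backends stop recognizing the mode & 4 (GLM) bit. Below is a hedged sketch of the call-site change, assuming the prototypes implied by the hunks that follow (the exact signatures live in ggml.h); build_rope is a made-up wrapper for illustration only.

static struct ggml_tensor * build_rope(struct ggml_context * ctx,
                                       struct ggml_tensor  * t,
                                       struct ggml_tensor  * pos,
                                       int n_rot, int mode, int n_orig_ctx,
                                       float freq_base, float freq_scale) {
    // before this commit: ggml_rope(ctx, t, pos, n_rot, mode, /*n_ctx=*/0);
    // after, the trailing n_ctx argument is gone:
    return ggml_rope(ctx, t, pos, n_rot, mode);

    // the extended variant drops the same argument; the remaining int is
    // the original training context used for YaRN correction:
    // return ggml_rope_ext(ctx, t, pos, /*freq_factors=*/NULL,
    //                      n_rot, mode, n_orig_ctx,
    //                      freq_base, freq_scale,
    //                      /*ext_factor=*/0.0f, /*attn_factor=*/1.0f,
    //                      /*beta_fast=*/0.0f, /*beta_slow=*/0.0f);
}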

examples/baby-llama/baby-llama.cpp

Lines changed: 6 additions & 6 deletions
@@ -522,8 +522,8 @@ static struct ggml_tensor * forward(
     // wk shape [n_embd, n_embd, 1, 1]
     // Qcur shape [n_embd/n_head, n_head, N, 1]
     // Kcur shape [n_embd/n_head, n_head, N, 1]
-    struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
-    struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
+    struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);
+    struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);

     // store key and value to memory
     {
@@ -759,8 +759,8 @@ static struct ggml_tensor * forward_batch(
     // wk shape [n_embd, n_embd, 1, 1]
     // Qcur shape [n_embd/n_head, n_head, N, n_batch]
     // Kcur shape [n_embd/n_head, n_head, N, n_batch]
-    struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
-    struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
+    struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
+    struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
     assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
     assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);

@@ -1056,7 +1056,7 @@ static struct ggml_tensor * forward_lora(
                 model->layers[il].wqb,
                 cur)),
             n_embd/n_head, n_head, N),
-        KQ_pos, n_rot, 0, 0);
+        KQ_pos, n_rot, 0);
     struct ggml_tensor * Kcur = ggml_rope(ctx0,
         ggml_reshape_3d(ctx0,
             ggml_mul_mat(ctx0,
@@ -1065,7 +1065,7 @@ static struct ggml_tensor * forward_lora(
                 model->layers[il].wkb,
                 cur)),
             n_embd/n_head, n_head, N),
-        KQ_pos, n_rot, 0, 0);
+        KQ_pos, n_rot, 0);

     // store key and value to memory
     {

examples/finetune/finetune.cpp

Lines changed: 1 addition & 1 deletion
@@ -564,7 +564,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int rope_mode = 0;

     return ggml_rope_ext(ctx,
-        t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
+        t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx,
         rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
     );
 };

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     const int rope_mode = 0;

     return ggml_rope_ext(
-        ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
     );
 };

ggml-cuda/rope.cu

Lines changed: 2 additions & 57 deletions
@@ -100,46 +100,6 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }

-static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
-    const float sin_block_theta = sinf(block_theta);
-    const float cos_block_theta = cosf(block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-

 template<typename T>
 static void rope_cuda(
@@ -200,17 +160,6 @@ static void rope_neox_cuda(
     }
 }

-static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, int n_ctx, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
-}
-
 static void rope_cuda_f16(
     const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
@@ -263,7 +212,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

     // RoPE alteration for extended context
@@ -279,7 +228,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int32_t * pos = nullptr;

     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;

     pos = (const int32_t *) src1_d;

@@ -295,10 +243,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
                 (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
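
The deleted GLM path split each row into four blocks of ncols/4 values: a standard rotary pair on the first half, with the position clamped to n_ctx - 2, plus a "block" rotation on the second half driven by how far the position runs past the window. For anyone who still needs that behaviour outside ggml, here is a hedged plain-C reconstruction of the per-row math, transcribed from the CUDA kernel above; the helper name is hypothetical, and note the original kernel carried a "FIXME: this is likely wrong" on the position lookup.

#include <math.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical reference, not part of ggml: what the removed rope_glm_f32
// kernels computed for one row of ncols floats.
static void rope_glm_f32_ref(const float * x, float * dst, int ncols, int row,
                             const int32_t * pos, float freq_scale,
                             int p_delta_rows, float freq_base, int n_ctx) {
    const int half_n_dims = ncols/4;
    const int p = pos != NULL ? pos[row/p_delta_rows] : 0;

    for (int col = 0; col < half_n_dims; ++col) {
        const int   i               = row*ncols + col;
        const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);

        // standard rotary pair, position clamped to the context window
        const float theta = (p < n_ctx - 2 ? p : n_ctx - 2)*freq_scale*col_theta_scale;
        const float x0 = x[i + 0];
        const float x1 = x[i + half_n_dims];
        dst[i + 0]           = x0*cosf(theta) - x1*sinf(theta);
        dst[i + half_n_dims] = x0*sinf(theta) + x1*cosf(theta);

        // "block" rotation, driven by how far p runs past the window
        const float block_theta = (float)(p - n_ctx - 2 > 0 ? p - n_ctx - 2 : 0)*col_theta_scale;
        const float x2 = x[i + half_n_dims*2];
        const float x3 = x[i + half_n_dims*3];
        dst[i + half_n_dims*2] = x2*cosf(block_theta) - x3*sinf(block_theta);
        dst[i + half_n_dims*3] = x2*sinf(block_theta) + x3*cosf(block_theta);
    }
}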

ggml-metal.m

Lines changed: 0 additions & 3 deletions
@@ -2296,9 +2296,6 @@ static enum ggml_status ggml_metal_graph_compute(
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;
-
-    GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");

     if (!is_neox) {
         GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");

ggml-sycl.cpp

Lines changed: 2 additions & 65 deletions
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }

-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-    , const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }

-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,7 +14007,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

     // RoPE alteration for extended context
@@ -14087,7 +14028,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }

     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;

     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14103,10 +14043,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
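
All of these backends decode GGML_OP_ROPE from the same dst->op_params blob: integer slots 0 through 4 as in the hunks above, plus float parameters bit-copied from later slots (beta_slow sits at index 10 in the Metal and Vulkan hunks). A minimal sketch of that decode follows; the slot layout is inferred from this commit rather than from ggml.h, and only slots that actually appear in the diffs are listed.

#include <stdint.h>
#include <string.h>

// Sketch of the GGML_OP_ROPE parameter decode used by the hunks above.
// The authoritative packing is done where the op is created, in ggml.c.
static void rope_decode_params(const int32_t op_params[16],
                               int * n_dims, int * mode, int * n_orig_ctx,
                               float * beta_slow) {
    //int n_past = op_params[0];   // kept commented out in every backend
    *n_dims     = op_params[1];
    *mode       = op_params[2];
    //int n_ctx  = op_params[3];   // dead after this commit: only GLM used it
    *n_orig_ctx = op_params[4];

    // float parameters are bit-copied out of the int32 blob, never value-cast
    memcpy(beta_slow, op_params + 10, sizeof(float));
}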

ggml-vulkan.cpp

Lines changed: 1 addition & 10 deletions
@@ -3823,11 +3823,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     {
         const int mode = ((const int32_t *) dst->op_params)[2];
         const bool is_neox = mode & 2;
-        const bool is_glm = mode & 4;
-
-        if (is_glm) {
-            return nullptr;
-        }

         if (is_neox) {
             if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -4307,9 +4302,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const float beta_slow = ((float *) dst->op_params)[10];

     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;
-
-    GGML_ASSERT(!is_glm);

     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@@ -6365,9 +6357,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
-                const bool is_glm = mode & 4;

-                return !is_glm;
+                return true;
             } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
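
With the GLM branch gone, only the neox bit of mode is meaningful to these backends, which is why ggml_backend_vk_supports_op can answer true unconditionally for GGML_OP_ROPE. A sketch of the bitfield as decoded in the hunks above; the helper name is illustrative, not a real ggml identifier.

#include <stdbool.h>

// mode bit 2 selects the neox rotation layout; bit 4 used to select GLM
// and is unused everywhere after this commit.
static bool rope_is_neox(int mode) { return (mode & 2) != 0; }
// removed in this commit: static bool rope_is_glm(int mode) { return (mode & 4) != 0; }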
