Skip to content

Commit 197fe6c

Browse files
authored
[SYCL] Update SYCL-Rope op and Refactor (#8157)
* align with rope.cu and move sycl-op to a single file
1 parent d0a7145 commit 197fe6c

File tree

4 files changed

+300
-303
lines changed

4 files changed

+300
-303
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 2 additions & 303 deletions
Original file line numberDiff line numberDiff line change
@@ -978,114 +978,6 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne,
978978
cpy_blck(cx + x_offset, cdst + dst_offset);
979979
}
980980

981-
static float rope_yarn_ramp(const float low, const float high, const int i0) {
982-
const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
983-
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
984-
}
985-
986-
struct rope_corr_dims {
987-
float v[4];
988-
};
989-
990-
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
991-
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
992-
static void rope_yarn(
993-
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
994-
float * cos_theta, float * sin_theta
995-
) {
996-
// Get n-d rotational scaling corrected for extrapolation
997-
float theta_interp = freq_scale * theta_extrap;
998-
float theta = theta_interp;
999-
if (ext_factor != 0.0f) {
1000-
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
1001-
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
1002-
1003-
// Get n-d magnitude scaling corrected for interpolation
1004-
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
1005-
}
1006-
*cos_theta = sycl::cos(theta) * mscale;
1007-
*sin_theta = sycl::sin(theta) * mscale;
1008-
}
1009-
1010-
// rope == RoPE == rotary positional embedding
1011-
template<typename T, bool has_pos>
1012-
static void rope(
1013-
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
1014-
float ext_factor, float attn_factor, rope_corr_dims corr_dims
1015-
,
1016-
const sycl::nd_item<3> &item_ct1) {
1017-
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1018-
item_ct1.get_local_id(1));
1019-
1020-
if (col >= ncols) {
1021-
return;
1022-
}
1023-
1024-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1025-
item_ct1.get_local_id(2);
1026-
const int i = row*ncols + col;
1027-
const int i2 = row/p_delta_rows;
1028-
1029-
const int p = has_pos ? pos[i2] : 0;
1030-
const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
1031-
1032-
float cos_theta, sin_theta;
1033-
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
1034-
1035-
const float x0 = x[i + 0];
1036-
const float x1 = x[i + 1];
1037-
1038-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
1039-
dst[i + 1] = x0*sin_theta + x1*cos_theta;
1040-
}
1041-
1042-
template<typename T, bool has_pos, bool has_freq_facs>
1043-
static void rope_neox(
1044-
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
1045-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
1046-
const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
1047-
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1048-
item_ct1.get_local_id(1));
1049-
1050-
if (col >= ncols) {
1051-
return;
1052-
}
1053-
1054-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1055-
item_ct1.get_local_id(2);
1056-
const int ib = col / n_dims;
1057-
const int ic = col % n_dims;
1058-
1059-
if (ib > 0) {
1060-
const int i = row*ncols + ib*n_dims + ic;
1061-
1062-
dst[i + 0] = x[i + 0];
1063-
dst[i + 1] = x[i + 1];
1064-
1065-
return;
1066-
}
1067-
1068-
const int i = row*ncols + ib*n_dims + ic/2;
1069-
const int i2 = row/p_delta_rows;
1070-
1071-
float cur_rot = inv_ndims * ic - ib;
1072-
1073-
const int p = has_pos ? pos[i2] : 0;
1074-
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
1075-
1076-
const float theta_base =
1077-
p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
1078-
1079-
float cos_theta, sin_theta;
1080-
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1081-
1082-
const float x0 = x[i + 0];
1083-
const float x1 = x[i + n_dims/2];
1084-
1085-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
1086-
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
1087-
}
1088-
1089981
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
1090982
const sycl::nd_item<3> &item_ct1) {
1091983
const int row = item_ct1.get_group(1);
@@ -2241,110 +2133,6 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
22412133
});
22422134
}
22432135

2244-
template <typename T>
2245-
static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
2246-
const int32_t *pos, float freq_scale, int p_delta_rows,
2247-
float freq_base, float ext_factor, float attn_factor,
2248-
rope_corr_dims corr_dims, queue_ptr stream) {
2249-
GGML_ASSERT(ncols % 2 == 0);
2250-
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
2251-
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
2252-
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
2253-
if (pos == nullptr) {
2254-
/*
2255-
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
2256-
the limit. To get the device limit, query
2257-
info::device::max_work_group_size. Adjust the work-group size if needed.
2258-
*/
2259-
dpct::has_capability_or_fail(stream->get_device(),
2260-
{sycl::aspect::fp16});
2261-
2262-
stream->parallel_for(
2263-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2264-
[=](sycl::nd_item<3> item_ct1) {
2265-
rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
2266-
freq_base, ext_factor, attn_factor, corr_dims,
2267-
item_ct1);
2268-
});
2269-
} else {
2270-
/*
2271-
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
2272-
the limit. To get the device limit, query
2273-
info::device::max_work_group_size. Adjust the work-group size if needed.
2274-
*/
2275-
dpct::has_capability_or_fail(stream->get_device(),
2276-
{sycl::aspect::fp16});
2277-
2278-
stream->parallel_for(
2279-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2280-
[=](sycl::nd_item<3> item_ct1) {
2281-
rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
2282-
freq_base, ext_factor, attn_factor, corr_dims,
2283-
item_ct1);
2284-
});
2285-
}
2286-
}
2287-
2288-
template <typename T>
2289-
static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
2290-
const int32_t *pos, float freq_scale,
2291-
int p_delta_rows, float freq_base, float ext_factor,
2292-
float attn_factor, rope_corr_dims corr_dims,
2293-
const float * freq_factors, queue_ptr stream) {
2294-
GGML_ASSERT(ncols % 2 == 0);
2295-
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
2296-
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
2297-
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
2298-
2299-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
2300-
const float inv_ndims = -1.0f / n_dims;
2301-
2302-
if (pos == nullptr) {
2303-
dpct::has_capability_or_fail(stream->get_device(),
2304-
{sycl::aspect::fp16});
2305-
if (freq_factors == nullptr) {
2306-
stream->parallel_for(
2307-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2308-
[=](sycl::nd_item<3> item_ct1) {
2309-
rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
2310-
p_delta_rows, ext_factor, attn_factor,
2311-
corr_dims, theta_scale, inv_ndims, freq_factors,
2312-
item_ct1);
2313-
});
2314-
} else {
2315-
stream->parallel_for(
2316-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2317-
[=](sycl::nd_item<3> item_ct1) {
2318-
rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
2319-
p_delta_rows, ext_factor, attn_factor,
2320-
corr_dims, theta_scale, inv_ndims, freq_factors,
2321-
item_ct1);
2322-
});
2323-
}
2324-
} else {
2325-
dpct::has_capability_or_fail(stream->get_device(),
2326-
{sycl::aspect::fp16});
2327-
2328-
if (freq_factors == nullptr) {
2329-
stream->parallel_for(
2330-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2331-
[=](sycl::nd_item<3> item_ct1) {
2332-
rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
2333-
p_delta_rows, ext_factor, attn_factor,
2334-
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
2335-
});
2336-
} else {
2337-
stream->parallel_for(
2338-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2339-
[=](sycl::nd_item<3> item_ct1) {
2340-
rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
2341-
p_delta_rows, ext_factor, attn_factor,
2342-
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
2343-
});
2344-
}
2345-
}
2346-
}
2347-
23482136
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
23492137
const int nrows, queue_ptr stream) {
23502138
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -3461,97 +3249,6 @@ catch (sycl::exception const &exc) {
34613249
std::exit(1);
34623250
}
34633251

3464-
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3465-
ggml_tensor *dst, const float *src0_dd,
3466-
const float *src1_dd, float *dst_dd,
3467-
const queue_ptr &main_stream) {
3468-
const ggml_tensor * src2 = dst->src[2];
3469-
3470-
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3471-
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
3472-
GGML_ASSERT(src0->type == dst->type);
3473-
3474-
const int64_t ne00 = src0->ne[0];
3475-
const int64_t ne01 = src0->ne[1];
3476-
const int64_t ne2 = dst->ne[2];
3477-
const int64_t nrows = ggml_nrows(src0);
3478-
3479-
//const int n_past = ((int32_t *) dst->op_params)[0];
3480-
const int n_dims = ((int32_t *) dst->op_params)[1];
3481-
const int mode = ((int32_t *) dst->op_params)[2];
3482-
//const int n_ctx = ((int32_t *) dst->op_params)[3];
3483-
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
3484-
3485-
// RoPE alteration for extended context
3486-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
3487-
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
3488-
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
3489-
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
3490-
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
3491-
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
3492-
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
3493-
3494-
const float * freq_factors = nullptr;
3495-
const int32_t * pos = nullptr;
3496-
if ((mode & 1) == 0) {
3497-
GGML_ASSERT(src1->type == GGML_TYPE_I32);
3498-
GGML_ASSERT(src1->ne[0] == ne2);
3499-
pos = (const int32_t *) src1_dd;
3500-
}
3501-
3502-
const bool is_neox = mode & 2;
3503-
3504-
#pragma message("TODO: update rope NORM mode to match NEOX mode")
3505-
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
3506-
3507-
if (is_neox) {
3508-
pos = (const int32_t *) src1_dd;
3509-
3510-
if (src2 != nullptr) {
3511-
freq_factors = (const float *) src2->data;
3512-
}
3513-
} else {
3514-
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
3515-
}
3516-
3517-
rope_corr_dims corr_dims;
3518-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
3519-
3520-
// compute
3521-
if (is_neox) {
3522-
if (src0->type == GGML_TYPE_F32) {
3523-
rope_neox_sycl(
3524-
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
3525-
attn_factor, corr_dims, freq_factors, main_stream
3526-
);
3527-
} else if (src0->type == GGML_TYPE_F16) {
3528-
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
3529-
ne00, n_dims, nrows, pos, freq_scale, ne01,
3530-
freq_base, ext_factor, attn_factor, corr_dims,
3531-
freq_factors, main_stream);
3532-
} else {
3533-
GGML_ASSERT(false);
3534-
}
3535-
} else {
3536-
if (src0->type == GGML_TYPE_F32) {
3537-
rope_sycl(
3538-
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
3539-
attn_factor, corr_dims, main_stream
3540-
);
3541-
} else if (src0->type == GGML_TYPE_F16) {
3542-
rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
3543-
nrows, pos, freq_scale, ne01, freq_base, ext_factor,
3544-
attn_factor, corr_dims, main_stream);
3545-
} else {
3546-
GGML_ASSERT(false);
3547-
}
3548-
}
3549-
3550-
(void) src1;
3551-
(void) dst;
3552-
(void) src1_dd;
3553-
}
3554-
35553252
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
35563253
const ggml_tensor *src1, ggml_tensor *dst,
35573254
const float *src0_dd, const float *src1_dd,
@@ -6241,7 +5938,9 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
62415938
case GGML_OP_CONT:
62425939
case GGML_OP_DIAG_MASK_INF:
62435940
case GGML_OP_SOFT_MAX:
5941+
return true;
62445942
case GGML_OP_ROPE:
5943+
return ggml_is_contiguous(op->src[0]);
62455944
case GGML_OP_IM2COL:
62465945
case GGML_OP_POOL_2D:
62475946
case GGML_OP_SUM_ROWS:

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
#include "dmmv.hpp"
2020
#include "mmq.hpp"
2121
#include "mmvq.hpp"
22+
#include "rope.hpp"
2223

2324
#endif // GGML_SYCL_BACKEND_HPP

0 commit comments

Comments
 (0)