@@ -9232,12 +9232,11 @@ static void rope(
9232
9232
dst[i + 1] = x0*sin_theta + x1*cos_theta;
9233
9233
}
9234
9234
9235
- template<typename T, bool has_pos>
9235
+ template<typename T, bool has_pos, bool has_freq_facs >
9236
9236
static void rope_neox(
9237
9237
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
9238
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
9239
- ,
9240
- const sycl::nd_item<3> &item_ct1) {
9238
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
9239
+ const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
9241
9240
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
9242
9241
item_ct1.get_local_id(1));
9243
9242
@@ -9265,8 +9264,10 @@ static void rope_neox(
9265
9264
float cur_rot = inv_ndims * ic - ib;
9266
9265
9267
9266
const int p = has_pos ? pos[i2] : 0;
9267
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
9268
+
9268
9269
const float theta_base =
9269
- p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
9270
+ p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor ;
9270
9271
9271
9272
float cos_theta, sin_theta;
9272
9273
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -12881,7 +12882,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12881
12882
const int32_t *pos, float freq_scale,
12882
12883
int p_delta_rows, float freq_base, float ext_factor,
12883
12884
float attn_factor, rope_corr_dims corr_dims,
12884
- dpct::queue_ptr stream) {
12885
+ const float * freq_factors, dpct::queue_ptr stream) {
12885
12886
GGML_ASSERT(ncols % 2 == 0);
12886
12887
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
12887
12888
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12891,38 +12892,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12891
12892
const float inv_ndims = -1.0f / n_dims;
12892
12893
12893
12894
if (pos == nullptr) {
12894
- /*
12895
- DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
12896
- the limit. To get the device limit, query
12897
- info::device::max_work_group_size. Adjust the work-group size if needed.
12898
- */
12899
12895
dpct::has_capability_or_fail(stream->get_device(),
12900
12896
{sycl::aspect::fp16});
12901
-
12902
- stream->parallel_for(
12903
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12904
- [=](sycl::nd_item<3> item_ct1) {
12905
- rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
12906
- p_delta_rows, ext_factor, attn_factor,
12907
- corr_dims, theta_scale, inv_ndims,
12908
- item_ct1);
12909
- });
12897
+ if (freq_factors == nullptr) {
12898
+ stream->parallel_for(
12899
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12900
+ [=](sycl::nd_item<3> item_ct1) {
12901
+ rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
12902
+ p_delta_rows, ext_factor, attn_factor,
12903
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12904
+ item_ct1);
12905
+ });
12906
+ } else {
12907
+ stream->parallel_for(
12908
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12909
+ [=](sycl::nd_item<3> item_ct1) {
12910
+ rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
12911
+ p_delta_rows, ext_factor, attn_factor,
12912
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12913
+ item_ct1);
12914
+ });
12915
+ }
12910
12916
} else {
12911
- /*
12912
- DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
12913
- the limit. To get the device limit, query
12914
- info::device::max_work_group_size. Adjust the work-group size if needed.
12915
- */
12916
12917
dpct::has_capability_or_fail(stream->get_device(),
12917
12918
{sycl::aspect::fp16});
12918
12919
12919
- stream->parallel_for(
12920
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12921
- [=](sycl::nd_item<3> item_ct1) {
12922
- rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
12923
- p_delta_rows, ext_factor, attn_factor,
12924
- corr_dims, theta_scale, inv_ndims, item_ct1);
12925
- });
12920
+ if (freq_factors == nullptr) {
12921
+ stream->parallel_for(
12922
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12923
+ [=](sycl::nd_item<3> item_ct1) {
12924
+ rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12925
+ p_delta_rows, ext_factor, attn_factor,
12926
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12927
+ });
12928
+ } else {
12929
+ stream->parallel_for(
12930
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12931
+ [=](sycl::nd_item<3> item_ct1) {
12932
+ rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12933
+ p_delta_rows, ext_factor, attn_factor,
12934
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12935
+ });
12936
+ }
12926
12937
}
12927
12938
}
12928
12939
@@ -14454,9 +14465,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14454
14465
ggml_tensor *dst, const float *src0_dd,
14455
14466
const float *src1_dd, float *dst_dd,
14456
14467
const dpct::queue_ptr &main_stream) {
14457
- #pragma message("TODO: implement phi3 frequency factors support")
14458
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
14459
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
14468
+ const ggml_tensor * src2 = dst->src[2];
14460
14469
14461
14470
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
14462
14471
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14482,6 +14491,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14482
14491
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
14483
14492
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
14484
14493
14494
+ const float * freq_factors = nullptr;
14485
14495
const int32_t * pos = nullptr;
14486
14496
if ((mode & 1) == 0) {
14487
14497
GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14492,6 +14502,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14492
14502
const bool is_neox = mode & 2;
14493
14503
const bool is_glm = mode & 4;
14494
14504
14505
+ if (is_neox) {
14506
+ pos = (const int32_t *) src1_dd;
14507
+
14508
+ if (src2 != nullptr) {
14509
+ freq_factors = (const float *) src2->data;
14510
+ }
14511
+ } else {
14512
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
14513
+ }
14514
+
14495
14515
rope_corr_dims corr_dims;
14496
14516
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
14497
14517
@@ -14503,13 +14523,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14503
14523
if (src0->type == GGML_TYPE_F32) {
14504
14524
rope_neox_sycl(
14505
14525
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
14506
- attn_factor, corr_dims, main_stream
14526
+ attn_factor, corr_dims, freq_factors, main_stream
14507
14527
);
14508
14528
} else if (src0->type == GGML_TYPE_F16) {
14509
14529
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
14510
14530
ne00, n_dims, nrows, pos, freq_scale, ne01,
14511
14531
freq_base, ext_factor, attn_factor, corr_dims,
14512
- main_stream);
14532
+ freq_factors, main_stream);
14513
14533
} else {
14514
14534
GGML_ASSERT(false);
14515
14535
}
0 commit comments