@@ -8826,7 +8826,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
8826
8826
}
8827
8827
8828
8828
struct rope_corr_dims {
8829
- float v[4 ];
8829
+ float v[2 ];
8830
8830
};
8831
8831
8832
8832
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
@@ -8850,29 +8850,38 @@ static void rope_yarn(
8850
8850
}
8851
8851
8852
8852
// rope == RoPE == rotary positional embedding
8853
- template<typename T, bool has_pos>
8854
- static void rope(
8855
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
8856
- float ext_factor, float attn_factor, rope_corr_dims corr_dims
8857
- ,
8853
+ template<typename T, bool has_ff>
8854
+ static void rope_norm(
8855
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
8856
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
8858
8857
const sycl::nd_item<3> &item_ct1) {
8859
- const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8858
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8860
8859
item_ct1.get_local_id(1));
8861
8860
8862
- if (col >= ncols ) {
8861
+ if (i0 >= ne0 ) {
8863
8862
return;
8864
8863
}
8865
8864
8866
8865
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8867
8866
item_ct1.get_local_id(2);
8868
- const int i = row*ncols + col;
8867
+
8868
+ if (i0 >= n_dims) {
8869
+ const int i = row*ne0 + i0;
8870
+
8871
+ dst[i + 0] = x[i + 0];
8872
+ dst[i + 1] = x[i + 1];
8873
+
8874
+ return;
8875
+ }
8876
+
8877
+ const int i = row*ne0 + i0;
8869
8878
const int i2 = row/p_delta_rows;
8870
8879
8871
- const int p = has_pos ? pos[i2] : 0 ;
8872
- const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols) ;
8880
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f) ;
8881
+ const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f ;
8873
8882
8874
8883
float cos_theta, sin_theta;
8875
- rope_yarn(theta_base, freq_scale, corr_dims, col , ext_factor, attn_factor, &cos_theta, &sin_theta);
8884
+ rope_yarn(theta_base/freq_factor , freq_scale, corr_dims, i0 , ext_factor, attn_factor, &cos_theta, &sin_theta);
8876
8885
8877
8886
const float x0 = x[i + 0];
8878
8887
const float x1 = x[i + 1];
@@ -8881,45 +8890,40 @@ static void rope(
8881
8890
dst[i + 1] = x0*sin_theta + x1*cos_theta;
8882
8891
}
8883
8892
8884
- template<typename T, bool has_pos, bool has_freq_facs>
8885
- static void rope_neox(
8886
- const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
8887
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
8888
- const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
8889
- const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8893
+ template <typename T, bool has_ff>
8894
+ static void rope_neox(const T *x, T *dst, int ne0, int n_dims,
8895
+ const int32_t *pos, float freq_scale, int p_delta_rows,
8896
+ float ext_factor, float attn_factor,
8897
+ rope_corr_dims corr_dims, float theta_scale,
8898
+ const float *freq_factors,
8899
+ const sycl::nd_item<3> &item_ct1) {
8900
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8890
8901
item_ct1.get_local_id(1));
8891
8902
8892
- if (col >= ncols ) {
8903
+ if (i0 >= ne0 ) {
8893
8904
return;
8894
8905
}
8895
8906
8896
8907
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8897
8908
item_ct1.get_local_id(2);
8898
- const int ib = col / n_dims;
8899
- const int ic = col % n_dims;
8900
8909
8901
- if (ib > 0 ) {
8902
- const int i = row*ncols + ib*n_dims + ic ;
8910
+ if (i0 >= n_dims ) {
8911
+ const int i = row*ne0 + i0 ;
8903
8912
8904
8913
dst[i + 0] = x[i + 0];
8905
8914
dst[i + 1] = x[i + 1];
8906
8915
8907
8916
return;
8908
8917
}
8909
8918
8910
- const int i = row*ncols + ib*n_dims + ic /2;
8919
+ const int i = row*ne0 + i0 /2;
8911
8920
const int i2 = row/p_delta_rows;
8912
8921
8913
- float cur_rot = inv_ndims * ic - ib;
8914
-
8915
- const int p = has_pos ? pos[i2] : 0;
8916
- const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
8917
-
8918
- const float theta_base =
8919
- p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
8922
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
8923
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
8920
8924
8921
8925
float cos_theta, sin_theta;
8922
- rope_yarn(theta_base, freq_scale, corr_dims, cur_rot , ext_factor, attn_factor, &cos_theta, &sin_theta);
8926
+ rope_yarn(theta_base/freq_factor , freq_scale, corr_dims, i0 , ext_factor, attn_factor, &cos_theta, &sin_theta);
8923
8927
8924
8928
const float x0 = x[i + 0];
8925
8929
const float x1 = x[i + n_dims/2];
@@ -12375,15 +12379,18 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
12375
12379
}
12376
12380
12377
12381
template <typename T>
12378
- static void rope_sycl (const T *x, T *dst, int ncols , int nrows ,
12382
+ static void rope_norm_sycl (const T *x, T *dst, int ne0 , int n_dims, int nr ,
12379
12383
const int32_t *pos, float freq_scale, int p_delta_rows,
12380
12384
float freq_base, float ext_factor, float attn_factor,
12381
- rope_corr_dims corr_dims, dpct::queue_ptr stream) {
12382
- GGML_ASSERT(ncols % 2 == 0);
12385
+ rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) {
12386
+ GGML_ASSERT(ne0 % 2 == 0);
12383
12387
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
12384
- const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
12385
- const sycl::range<3> block_nums(1, num_blocks_x, nrows);
12386
- if (pos == nullptr) {
12388
+ const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
12389
+ const sycl::range<3> block_nums(1, n_blocks_x, nr);
12390
+
12391
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
12392
+
12393
+ if (freq_factors == nullptr) {
12387
12394
/*
12388
12395
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
12389
12396
the limit. To get the device limit, query
@@ -12395,8 +12402,8 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
12395
12402
stream->parallel_for(
12396
12403
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12397
12404
[=](sycl::nd_item<3> item_ct1) {
12398
- rope <T, false>(x, dst, ncols , pos, freq_scale, p_delta_rows,
12399
- freq_base, ext_factor, attn_factor, corr_dims,
12405
+ rope_norm <T, false>(x, dst, ne0, n_dims , pos, freq_scale, p_delta_rows,
12406
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ,
12400
12407
item_ct1);
12401
12408
});
12402
12409
} else {
@@ -12411,70 +12418,46 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
12411
12418
stream->parallel_for(
12412
12419
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12413
12420
[=](sycl::nd_item<3> item_ct1) {
12414
- rope <T, true>(x, dst, ncols , pos, freq_scale, p_delta_rows,
12415
- freq_base, ext_factor, attn_factor, corr_dims,
12421
+ rope_norm <T, true>(x, dst, ne0, n_dims , pos, freq_scale, p_delta_rows,
12422
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ,
12416
12423
item_ct1);
12417
12424
});
12418
12425
}
12419
12426
}
12420
12427
12421
12428
template <typename T>
12422
- static void rope_neox_sycl(const T *x, T *dst, int ncols , int n_dims, int nrows ,
12429
+ static void rope_neox_sycl(const T *x, T *dst, int ne0 , int n_dims, int nr ,
12423
12430
const int32_t *pos, float freq_scale,
12424
12431
int p_delta_rows, float freq_base, float ext_factor,
12425
12432
float attn_factor, rope_corr_dims corr_dims,
12426
12433
const float * freq_factors, dpct::queue_ptr stream) {
12427
- GGML_ASSERT(ncols % 2 == 0);
12434
+ GGML_ASSERT(ne0 % 2 == 0);
12428
12435
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
12429
- const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
12430
- const sycl::range<3> block_nums(1, num_blocks_x, nrows );
12436
+ const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
12437
+ const sycl::range<3> block_nums(1, n_blocks_x, nr );
12431
12438
12432
12439
const float theta_scale = powf(freq_base, -2.0f/n_dims);
12433
- const float inv_ndims = -1.0f / n_dims;
12434
12440
12435
- if (pos == nullptr) {
12436
12441
dpct::has_capability_or_fail(stream->get_device(),
12437
12442
{sycl::aspect::fp16});
12438
12443
if (freq_factors == nullptr) {
12439
12444
stream->parallel_for(
12440
12445
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12441
12446
[=](sycl::nd_item<3> item_ct1) {
12442
- rope_neox<T, false, false >(x, dst, ncols , n_dims, pos, freq_scale,
12447
+ rope_neox<T, false>(x, dst, ne0 , n_dims, pos, freq_scale,
12443
12448
p_delta_rows, ext_factor, attn_factor,
12444
- corr_dims, theta_scale, inv_ndims , freq_factors,
12449
+ corr_dims, theta_scale, freq_factors,
12445
12450
item_ct1);
12446
12451
});
12447
12452
} else {
12448
12453
stream->parallel_for(
12449
12454
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12450
12455
[=](sycl::nd_item<3> item_ct1) {
12451
- rope_neox<T, false, true>(x, dst, ncols , n_dims, pos, freq_scale,
12456
+ rope_neox<T, true>(x, dst, ne0 , n_dims, pos, freq_scale,
12452
12457
p_delta_rows, ext_factor, attn_factor,
12453
- corr_dims, theta_scale, inv_ndims , freq_factors,
12458
+ corr_dims, theta_scale, freq_factors,
12454
12459
item_ct1);
12455
12460
});
12456
- }
12457
- } else {
12458
- dpct::has_capability_or_fail(stream->get_device(),
12459
- {sycl::aspect::fp16});
12460
-
12461
- if (freq_factors == nullptr) {
12462
- stream->parallel_for(
12463
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12464
- [=](sycl::nd_item<3> item_ct1) {
12465
- rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12466
- p_delta_rows, ext_factor, attn_factor,
12467
- corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12468
- });
12469
- } else {
12470
- stream->parallel_for(
12471
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12472
- [=](sycl::nd_item<3> item_ct1) {
12473
- rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12474
- p_delta_rows, ext_factor, attn_factor,
12475
- corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12476
- });
12477
- }
12478
12461
}
12479
12462
}
12480
12463
@@ -14005,8 +13988,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14005
13988
14006
13989
const int64_t ne00 = src0->ne[0];
14007
13990
const int64_t ne01 = src0->ne[1];
14008
- const int64_t ne2 = dst->ne[2];
14009
- const int64_t nrows = ggml_nrows(src0);
13991
+ const int64_t nr = ggml_nrows(src0);
14010
13992
14011
13993
//const int n_past = ((int32_t *) dst->op_params)[0];
14012
13994
const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -14023,27 +14005,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14023
14005
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
14024
14006
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
14025
14007
14026
- const float * freq_factors = nullptr;
14027
- const int32_t * pos = nullptr;
14028
- if ((mode & 1) == 0) {
14029
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
14030
- GGML_ASSERT(src1->ne[0] == ne2);
14031
- pos = (const int32_t *) src1_dd;
14032
- }
14033
-
14034
14008
const bool is_neox = mode & 2;
14035
14009
14036
- #pragma message("TODO: update rope NORM mode to match NEOX mode")
14037
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
14038
-
14039
- if (is_neox) {
14040
- pos = (const int32_t *) src1_dd;
14010
+ const int32_t * pos = (const int32_t *) src1_dd;
14041
14011
14012
+ const float * freq_factors = nullptr;
14042
14013
if (src2 != nullptr) {
14043
14014
freq_factors = (const float *) src2->data;
14044
- }
14045
- } else {
14046
- GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
14047
14015
}
14048
14016
14049
14017
rope_corr_dims corr_dims;
@@ -14053,27 +14021,27 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14053
14021
if (is_neox) {
14054
14022
if (src0->type == GGML_TYPE_F32) {
14055
14023
rope_neox_sycl(
14056
- (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
14024
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
14057
14025
attn_factor, corr_dims, freq_factors, main_stream
14058
14026
);
14059
14027
} else if (src0->type == GGML_TYPE_F16) {
14060
14028
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
14061
- ne00, n_dims, nrows , pos, freq_scale, ne01,
14029
+ ne00, n_dims, nr , pos, freq_scale, ne01,
14062
14030
freq_base, ext_factor, attn_factor, corr_dims,
14063
14031
freq_factors, main_stream);
14064
14032
} else {
14065
14033
GGML_ASSERT(false);
14066
14034
}
14067
14035
} else {
14068
14036
if (src0->type == GGML_TYPE_F32) {
14069
- rope_sycl (
14070
- (const float *)src0_dd, (float *)dst_dd, ne00, nrows , pos, freq_scale, ne01, freq_base, ext_factor,
14071
- attn_factor, corr_dims, main_stream
14037
+ rope_norm_sycl (
14038
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
14039
+ attn_factor, corr_dims, freq_factors, main_stream
14072
14040
);
14073
14041
} else if (src0->type == GGML_TYPE_F16) {
14074
- rope_sycl ((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
14075
- nrows , pos, freq_scale, ne01, freq_base, ext_factor,
14076
- attn_factor, corr_dims, main_stream);
14042
+ rope_norm_sycl ((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
14043
+ n_dims, nr , pos, freq_scale, ne01, freq_base, ext_factor,
14044
+ attn_factor, corr_dims, freq_factors, main_stream);
14077
14045
} else {
14078
14046
GGML_ASSERT(false);
14079
14047
}
0 commit comments