Skip to content

Commit 5487593

Browse files
authored
Add freq factors (#7495)
1 parent 1d8fca7 commit 5487593

File tree

1 file changed

+57
-37
lines changed

1 file changed

+57
-37
lines changed

ggml-sycl.cpp

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8830,12 +8830,11 @@ static void rope(
88308830
dst[i + 1] = x0*sin_theta + x1*cos_theta;
88318831
}
88328832

8833-
template<typename T, bool has_pos>
8833+
template<typename T, bool has_pos, bool has_freq_facs>
88348834
static void rope_neox(
88358835
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
8836-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
8837-
,
8838-
const sycl::nd_item<3> &item_ct1) {
8836+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
8837+
const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
88398838
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
88408839
item_ct1.get_local_id(1));
88418840

@@ -8863,8 +8862,10 @@ static void rope_neox(
88638862
float cur_rot = inv_ndims * ic - ib;
88648863

88658864
const int p = has_pos ? pos[i2] : 0;
8865+
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
8866+
88668867
const float theta_base =
8867-
p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
8868+
p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
88688869

88698870
float cos_theta, sin_theta;
88708871
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -12413,7 +12414,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
1241312414
const int32_t *pos, float freq_scale,
1241412415
int p_delta_rows, float freq_base, float ext_factor,
1241512416
float attn_factor, rope_corr_dims corr_dims,
12416-
dpct::queue_ptr stream) {
12417+
const float * freq_factors, dpct::queue_ptr stream) {
1241712418
GGML_ASSERT(ncols % 2 == 0);
1241812419
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
1241912420
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,38 +12424,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
1242312424
const float inv_ndims = -1.0f / n_dims;
1242412425

1242512426
if (pos == nullptr) {
12426-
/*
12427-
DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
12428-
the limit. To get the device limit, query
12429-
info::device::max_work_group_size. Adjust the work-group size if needed.
12430-
*/
1243112427
dpct::has_capability_or_fail(stream->get_device(),
1243212428
{sycl::aspect::fp16});
12433-
12434-
stream->parallel_for(
12435-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12436-
[=](sycl::nd_item<3> item_ct1) {
12437-
rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
12438-
p_delta_rows, ext_factor, attn_factor,
12439-
corr_dims, theta_scale, inv_ndims,
12440-
item_ct1);
12441-
});
12429+
if (freq_factors == nullptr) {
12430+
stream->parallel_for(
12431+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12432+
[=](sycl::nd_item<3> item_ct1) {
12433+
rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
12434+
p_delta_rows, ext_factor, attn_factor,
12435+
corr_dims, theta_scale, inv_ndims, freq_factors,
12436+
item_ct1);
12437+
});
12438+
} else {
12439+
stream->parallel_for(
12440+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12441+
[=](sycl::nd_item<3> item_ct1) {
12442+
rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
12443+
p_delta_rows, ext_factor, attn_factor,
12444+
corr_dims, theta_scale, inv_ndims, freq_factors,
12445+
item_ct1);
12446+
});
12447+
}
1244212448
} else {
12443-
/*
12444-
DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
12445-
the limit. To get the device limit, query
12446-
info::device::max_work_group_size. Adjust the work-group size if needed.
12447-
*/
1244812449
dpct::has_capability_or_fail(stream->get_device(),
1244912450
{sycl::aspect::fp16});
1245012451

12451-
stream->parallel_for(
12452-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12453-
[=](sycl::nd_item<3> item_ct1) {
12454-
rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
12455-
p_delta_rows, ext_factor, attn_factor,
12456-
corr_dims, theta_scale, inv_ndims, item_ct1);
12457-
});
12452+
if (freq_factors == nullptr) {
12453+
stream->parallel_for(
12454+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12455+
[=](sycl::nd_item<3> item_ct1) {
12456+
rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12457+
p_delta_rows, ext_factor, attn_factor,
12458+
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12459+
});
12460+
} else {
12461+
stream->parallel_for(
12462+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
12463+
[=](sycl::nd_item<3> item_ct1) {
12464+
rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12465+
p_delta_rows, ext_factor, attn_factor,
12466+
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12467+
});
12468+
}
1245812469
}
1245912470
}
1246012471

@@ -13986,9 +13997,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1398613997
ggml_tensor *dst, const float *src0_dd,
1398713998
const float *src1_dd, float *dst_dd,
1398813999
const dpct::queue_ptr &main_stream) {
13989-
#pragma message("TODO: implement phi3 frequency factors support")
13990-
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
13991-
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
14000+
const ggml_tensor * src2 = dst->src[2];
1399214001

1399314002
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
1399414003
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14014,6 +14023,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1401414023
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1401514024
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1401614025

14026+
const float * freq_factors = nullptr;
1401714027
const int32_t * pos = nullptr;
1401814028
if ((mode & 1) == 0) {
1401914029
GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14024,6 +14034,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1402414034
const bool is_neox = mode & 2;
1402514035
const bool is_glm = mode & 4;
1402614036

14037+
if (is_neox) {
14038+
pos = (const int32_t *) src1_dd;
14039+
14040+
if (src2 != nullptr) {
14041+
freq_factors = (const float *) src2->data;
14042+
}
14043+
} else {
14044+
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
14045+
}
14046+
1402714047
rope_corr_dims corr_dims;
1402814048
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
1402914049

@@ -14035,13 +14055,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
1403514055
if (src0->type == GGML_TYPE_F32) {
1403614056
rope_neox_sycl(
1403714057
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
14038-
attn_factor, corr_dims, main_stream
14058+
attn_factor, corr_dims, freq_factors, main_stream
1403914059
);
1404014060
} else if (src0->type == GGML_TYPE_F16) {
1404114061
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
1404214062
ne00, n_dims, nrows, pos, freq_scale, ne01,
1404314063
freq_base, ext_factor, attn_factor, corr_dims,
14044-
main_stream);
14064+
freq_factors, main_stream);
1404514065
} else {
1404614066
GGML_ASSERT(false);
1404714067
}

0 commit comments

Comments
 (0)