@@ -978,114 +978,6 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne,
    cpy_blck(cx + x_offset, cdst + dst_offset);
}

-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
-    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
-}
-
-struct rope_corr_dims {
-    float v[4];
-};
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
-    }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
-}
-
-// rope == RoPE == rotary positional embedding
-template <typename T, bool has_pos>
-static void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-,
-    const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-}
-
-template <typename T, bool has_pos, bool has_freq_facs>
-static void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
-    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
-
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i  = row*ncols + ib*n_dims + ic/2;
-    const int i2 = row/p_delta_rows;
-
-    float cur_rot = inv_ndims * ic - ib;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-
-    const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col/2.0f)/freq_factor;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
-
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
-}
-
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                           const sycl::nd_item<3> &item_ct1) {
    const int row = item_ct1.get_group(1);
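
Note on the removed device helpers: rope_yarn_ramp/rope_yarn implement the YaRN angle correction — the interpolated angle freq_scale*theta is blended with the extrapolated angle according to a ramp over the corrected dimension range, and the magnitude is rescaled by 1 + 0.1*ln(1/freq_scale). The following is a minimal host-side sketch of the same math in plain C++ (no SYCL); the function names and the values in main() are illustrative only, not part of ggml.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Ramp in [0, 1] over the corrected dimension range [low, high), as in rope_yarn_ramp().
static float yarn_ramp(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}

// Blend interpolated and extrapolated RoPE angles, as in the removed rope_yarn().
static void yarn_angles(float theta_extrap, float freq_scale, float low, float high,
                        int i0, float ext_factor, float mscale,
                        float *cos_theta, float *sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        const float ramp_mix = yarn_ramp(low, high, i0) * ext_factor;
        theta   = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    *cos_theta = std::cos(theta) * mscale;
    *sin_theta = std::sin(theta) * mscale;
}

int main() {
    float c, s;
    // Illustrative values: extrapolated angle 1.0, 4x context scaling, dims 0..8 corrected.
    yarn_angles(1.0f, 0.25f, 0.0f, 8.0f, 1.0f, 1.0f, &c, &s);
    std::printf("cos=%f sin=%f\n", c, s);
    return 0;
}
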
@@ -2241,110 +2133,6 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
        });
}

-template <typename T>
-static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
-                      const int32_t *pos, float freq_scale, int p_delta_rows,
-                      float freq_base, float ext_factor, float attn_factor,
-                      rope_corr_dims corr_dims, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-    if (pos == nullptr) {
-        /*
-        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                               freq_base, ext_factor, attn_factor, corr_dims,
-                               item_ct1);
-            });
-    } else {
-        /*
-        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                              freq_base, ext_factor, attn_factor, corr_dims,
-                              item_ct1);
-            });
-    }
-}
-
-template <typename T>
-static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
-                           const int32_t *pos, float freq_scale,
-                           int p_delta_rows, float freq_base, float ext_factor,
-                           float attn_factor, rope_corr_dims corr_dims,
-                           const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
-
-    if (pos == nullptr) {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                               p_delta_rows, ext_factor, attn_factor,
-                                               corr_dims, theta_scale, inv_ndims, freq_factors,
-                                               item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors,
-                                              item_ct1);
-                });
-        }
-    } else {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                             p_delta_rows, ext_factor, attn_factor,
-                                             corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        }
-    }
-}
-
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                              const int nrows, queue_ptr stream) {
    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
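
Note on the removed launchers: each work-item rotates one pair of elements, so a row of ncols values needs ceil(ncols / (2 * SYCL_ROPE_BLOCK_SIZE)) work-groups along one axis and nrows along the other. A tiny standalone sketch of that grid arithmetic (the block size of 256 and the tensor shape here are assumptions for illustration):

#include <cstdio>

int main() {
    const int SYCL_ROPE_BLOCK_SIZE = 256; // assumed block size, for illustration
    const int ncols = 4096;               // elements per row, illustrative
    const int nrows = 32;                 // rows (tokens * heads), illustrative

    // Same rounding-up formula as in the removed rope_sycl()/rope_neox_sycl().
    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
    std::printf("nd_range: %d x %d work-groups of %d work-items\n",
                num_blocks_x, nrows, SYCL_ROPE_BLOCK_SIZE);
    return 0;
}
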
@@ -3461,97 +3249,6 @@ catch (sycl::exception const &exc) {
  std::exit(1);
}

-inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-    const ggml_tensor * src2 = dst->src[2];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_dd;
-    }
-
-    const bool is_neox = mode & 2;
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
-    if (is_neox) {
-        pos = (const int32_t *) src1_dd;
-
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
-    }
-
-    rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    // compute
-    if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
-                           ne00, n_dims, nrows, pos, freq_scale, ne01,
-                           freq_base, ext_factor, attn_factor, corr_dims,
-                           freq_factors, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
-                      nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                      attn_factor, corr_dims, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                const ggml_tensor *src1, ggml_tensor *dst,
                                const float *src0_dd, const float *src1_dd,
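
Note on the removed ggml_sycl_op_rope: the RoPE hyper-parameters travel in dst->op_params, an int32 array, and the float fields are recovered with memcpy rather than a pointer cast to stay clear of type-punning issues. A small standalone sketch of that packing/unpacking convention (same field order as the removed code; the pack_float helper is hypothetical, not a ggml API):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: store a float bit-for-bit in an int32 slot of op_params.
static void pack_float(int32_t *params, int idx, float v) {
    std::memcpy(params + idx, &v, sizeof(float));
}

int main() {
    int32_t op_params[16] = {0};
    op_params[1] = 128;                  // n_dims
    op_params[2] = 2;                    // mode (bit 1 set -> NEOX style)
    op_params[4] = 4096;                 // n_ctx_orig
    pack_float(op_params, 5, 10000.0f);  // freq_base
    pack_float(op_params, 6, 1.0f);      // freq_scale

    // Unpacking, exactly as the removed ggml_sycl_op_rope() did it:
    float freq_base, freq_scale;
    std::memcpy(&freq_base,  op_params + 5, sizeof(float));
    std::memcpy(&freq_scale, op_params + 6, sizeof(float));
    const bool is_neox = op_params[2] & 2;

    std::printf("n_dims=%d neox=%d freq_base=%.1f freq_scale=%.2f\n",
                op_params[1], is_neox ? 1 : 0, freq_base, freq_scale);
    return 0;
}
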
@@ -6241,7 +5938,9 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
+            return true;
        case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_IM2COL:
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM_ROWS: