@@ -3154,6 +3154,7 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
3154
3154
#define SYCL_SCALE_BLOCK_SIZE 256
3155
3155
#define SYCL_CLAMP_BLOCK_SIZE 256
3156
3156
#define SYCL_ROPE_BLOCK_SIZE 256
3157
+ #define SYCL_SOFT_MAX_BLOCK_SIZE 1024
3157
3158
#define SYCL_ALIBI_BLOCK_SIZE 32
3158
3159
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
3159
3160
#define SYCL_QUANTIZE_BLOCK_SIZE 256
@@ -13079,13 +13080,11 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13079
13080
const int nrows_y, const float scale, const float max_bias,
13080
13081
dpct::queue_ptr stream) {
13081
13082
int nth = WARP_SIZE;
13082
- int max_block_size = g_work_group_size;
13083
- while (nth < ncols_x && nth < max_block_size) nth *= 2;
13084
- if (nth>max_block_size) nth = max_block_size;
13085
-
13083
+ while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
13086
13084
const sycl::range<3> block_dims(1, 1, nth);
13087
13085
const sycl::range<3> block_nums(1, 1, nrows_x);
13088
13086
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
13087
+ static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
13089
13088
13090
13089
const uint32_t n_head_kv = nrows_x/nrows_y;
13091
13090
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -13095,12 +13094,6 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13095
13094
13096
13095
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13097
13096
if (n_local_scratch*sizeof(float) < local_mem_size) {
13098
- if (ncols_x > max_block_size) {
13099
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13100
- max_bias, m0, m1, n_head_log2, block_nums,
13101
- block_dims, n_local_scratch, stream);
13102
- return;
13103
- }
13104
13097
switch (ncols_x) {
13105
13098
case 32:
13106
13099
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
@@ -16825,13 +16818,11 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
16825
16818
const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
16826
16819
SYCL_CHECK(
16827
16820
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
16828
- char* host_buf = (char*)malloc(size);
16829
- memcpy(host_buf, data, size);
16821
+
16830
16822
SYCL_CHECK(
16831
16823
CHECK_TRY_ERROR((*stream)
16832
- .memcpy((char *)tensor->data + offset, host_buf , size)
16824
+ .memcpy((char *)tensor->data + offset, data , size)
16833
16825
.wait()));
16834
- free(host_buf);
16835
16826
}
16836
16827
catch (sycl::exception const &exc) {
16837
16828
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
0 commit comments