Skip to content

Commit 228f34c

Browse files
authored
SYCL: Implement few same quantized type copy kernels (#13739)
* SYCL: Implement a few same-type quantized copy kernels
* Use memcpy for copying contiguous tensors ggml-ci
* feat(sycl): add contiguous tensor copy support and device checks. Adds a memcpy path for contiguous tensors of the same type to optimize data transfer. Updates device support checks to recognize contiguous tensor operations, improving compatibility and performance.
* refactor: replace specific block copy functions with template. The changes replace multiple redundant block copy functions (e.g., cpy_block_q8_0_q8_0, cpy_block_q5_0_q5_0) with a single templated function cpy_blck_q_q. This reduces code duplication by using a generic template that works for any block type, improving maintainability while preserving the same functionality. The template is instantiated with specific block types (e.g., block_q8_0) where needed.
* Exclude BF16 support for COPY tensors for now ggml-ci
* perf: adjust SYCL copy kernel block sizes for efficiency. Use ceil_div to ensure full element coverage and update nd_range parameters to better align with SYCL block sizes, improving parallelism and device utilization in copy operations.
1 parent 0974ad7 commit 228f34c

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

ggml/src/ggml-sycl/cpy.cpp

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#include "cpy.hpp"
22

33
#include <float.h>
4+
#include <string>
45

56
#include "dequantize.hpp"
7+
#include "ggml-sycl/common.hpp"
8+
#include "ggml-sycl/presets.hpp"
9+
#include "ggml.h"
610

711
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
812
if (x <= val[0]) {
@@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
116120
}
117121
}
118122

123+
/*
 * Same-type block copy: reinterprets both byte pointers as a single object of
 * type T and assigns the source object to the destination.
 *
 *   cxi   - address of the source block
 *   cdsti - address of the destination block
 */
template <typename T>
static void cpy_blck_q_q(const char * cxi, char * cdsti) {
    const T & src_block = *reinterpret_cast<const T *>(cxi);
    T &       dst_block = *reinterpret_cast<T *>(cdsti);

    dst_block = src_block;
}
130+
131+
119132
static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
120133
float * cdstf = (float *) (cdsti);
121134

@@ -311,6 +324,34 @@ template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const
311324
}
312325
}
313326

327+
328+
template <typename T, int qk>
329+
static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
330+
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
331+
const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
332+
const sycl::nd_item<3> & item_ct1) {
333+
const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk;
334+
335+
if (i >= ne) {
336+
return;
337+
}
338+
339+
const int i03 = i / (ne00 * ne01 * ne02);
340+
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
341+
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
342+
const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
343+
const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
344+
345+
346+
const int i13 = i / (ne10 * ne11 * ne12);
347+
const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
348+
const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
349+
const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
350+
const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
351+
352+
cpy_blck_q_q<T>(cx + x_offset, cdst + dst_offset);
353+
}
354+
314355
template <cpy_kernel_t cpy_blck, int qk>
315356
static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02,
316357
const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11,
@@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00
322363
return;
323364
}
324365

366+
325367
const int i03 = i / (ne00 * ne01 * ne02);
326368
const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
327369
const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
@@ -615,6 +657,70 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
615657
}
616658
}
617659

660+
static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
661+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
662+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
663+
const int nb12, const int nb13, queue_ptr stream) {
664+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
665+
stream->parallel_for(
666+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
667+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
668+
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
669+
});
670+
}
671+
672+
673+
static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
674+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
675+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
676+
const int nb12, const int nb13, queue_ptr stream) {
677+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
678+
stream->parallel_for(
679+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
680+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
681+
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
682+
});
683+
}
684+
685+
686+
static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
687+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
688+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
689+
const int nb12, const int nb13, queue_ptr stream) {
690+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
691+
692+
stream->parallel_for(
693+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
694+
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
695+
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
696+
});
697+
}
698+
699+
700+
static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
701+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
702+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
703+
const int nb12, const int nb13, queue_ptr stream) {
704+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
705+
stream->parallel_for(
706+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
707+
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
708+
});
709+
}
710+
711+
712+
static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
713+
const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
714+
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
715+
const int nb12, const int nb13, queue_ptr stream) {
716+
717+
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
718+
stream->parallel_for(
719+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
720+
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
721+
});
722+
}
723+
618724
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
619725
// Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field
620726
scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0,
@@ -632,8 +738,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
632738

633739
char * src0_ddc = (char *) src0->data;
634740
char * src1_ddc = (char *) src1->data;
635-
636-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
741+
if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) {
742+
GGML_SYCL_DEBUG("%s: memcpy path\n", __func__);
743+
main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0));
744+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
637745
ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10,
638746
nb11, nb12, nb13, main_stream);
639747
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
@@ -684,6 +792,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
684792
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
685793
ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12,
686794
nb10, nb11, nb12, nb13, main_stream);
795+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
796+
ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
797+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) {
798+
ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
799+
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) {
800+
ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
801+
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) {
802+
ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
803+
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) {
804+
ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
687805
} else {
688806
GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type),
689807
ggml_type_name(src1->type));

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4226,6 +4226,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42264226
{
42274227
ggml_type src0_type = op->src[0]->type;
42284228
ggml_type src1_type = op->src[1]->type;
4229+
if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
4230+
return true;
4231+
}
42294232
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
42304233
return true;
42314234
}
@@ -4271,6 +4274,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42714274
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
42724275
return true;
42734276
}
4277+
if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
4278+
return true;
4279+
}
4280+
if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
4281+
return true;
4282+
}
4283+
if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
4284+
return true;
4285+
}
4286+
if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
4287+
return true;
4288+
}
4289+
if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
4290+
return true;
4291+
}
42744292
return false;
42754293
}
42764294
case GGML_OP_CONCAT:

0 commit comments

Comments
 (0)